feat: YAML-driven torch op codegen with canonical naming and exposed semantic params #595

Open

voltjia wants to merge 6 commits into master from feat/torch-codegen

Conversation

@voltjia
Collaborator

@voltjia voltjia commented May 9, 2026

Summary

  • Add scripts/generate_torch_ops.py (~920 lines) — a YAML-driven codegen that consumes PyTorch's aten/native_functions.yaml and emits an InfiniOps base class plus a slot-8 PyTorch backend per op listed in scripts/torch_ops.yaml (~459 ops, generating 507 overloads across 437 canonical classes).
  • Wire the codegen into CMake (src/CMakeLists.txt) under WITH_TORCH=ON: invoke at configure time, glob generated/torch/*.cc, add generated/ to public include paths, install the per-op metadata JSON alongside the bindings.
  • Update the wrapper generator (scripts/generate_wrappers.py) to scan generated/base/ and generated/torch/, fix pybind11 overload ordering (specific → permissive), preserve std::vector<int64_t> parameters that libclang misreports as int, and route active_implementation_indices through a graceful unknown-device path.
  • Add safe device-type lookup primitives (detail::ListContains in src/operator.h, TryDeviceTypeFromString in src/pybind11_utils.h) so generated bindings handle devices an op does not implement without aborting.
  • Add a single data-driven tests/test_torch_ops.py that reads generated/torch_ops_metadata.json and exercises every generated op across three shapes and three dtypes; widen tests/conftest.py to handle non-floating outputs and equal_nan.
  • Move the Sigmoid helper in src/native/cuda/ops/swiglu/kernel.cuh into detail:: so it does not collide with the auto-generated infini::ops::Sigmoid operator class.
  • Add pyyaml to [build-system].requires so CMake can run the codegen during pip install.
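As a rough sketch of the selection step (the helper name `select_ops` is hypothetical; the real `generate_torch_ops.py` is not reproduced here), the codegen can be thought of as filtering parsed `native_functions.yaml` entries down to the base names listed in `scripts/torch_ops.yaml`:

```python
def select_ops(schema_entries, selected_names):
    """Keep only the ATen schema entries whose base name is selected.

    `schema_entries` is the parsed list from `native_functions.yaml`; each
    entry's `func` field looks like
    `"triu.out(Tensor self, int diagonal=0, ...) -> ..."`.
    """
    selected = set(selected_names)
    kept = []
    for entry in schema_entries:
        signature = entry["func"]
        name = signature.split("(", 1)[0]  # e.g. `triu.out`
        base = name.split(".", 1)[0]       # strip the ATen overload name
        if base in selected:
            kept.append(entry)
    return kept
```

Grouping by the base name (rather than the full `name.overload` pair) is what lets multiple ATen overloads collapse into a single canonical class.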

Codegen design choices driven by review feedback collected across all 513 base PRs against feat/torch-codegen:

  • Canonical names only. ATen overload-name suffixes (_grad_input, _outtensor, _n_scalar, _values, _x, _l, _q, _u, _output) no longer leak into InfiniOps class names. Multiple ATen overloads of the same base op share a single class, with overloaded operator() methods.
  • Visible scalars are members. Every visible non-tensor parameter (scalars, strings, vectors) is stored as a base-class member initialized from the constructor argument, alongside the existing tensor-metadata members.
  • Default-valued non-optional params are exposed. bool upper, bool transpose, bool unitriangular (triangular_solve), int diagonal (triu), str ord (linalg_matrix_norm), int n on the chebyshev/hermite polynomial families, etc. are no longer hidden because they have an ATen default — they are now visible in the generated operator() and forwarded to ATen.
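A minimal sketch of the canonical-naming rule, assuming the suffix list from the bullet above (`canonical_class_name` is a hypothetical helper, not the actual codegen function):

```python
# Overload-name suffixes that must not leak into the public class name.
OVERLOAD_SUFFIXES = (
    "_grad_input", "_outtensor", "_n_scalar", "_values",
    "_x", "_l", "_q", "_u", "_output",
)

def canonical_class_name(aten_name: str) -> str:
    """Map an ATen op name to a canonical PascalCase class name."""
    base = aten_name
    for suffix in OVERLOAD_SUFFIXES:
        if base.endswith(suffix):
            base = base[: -len(suffix)]
            break
    # PascalCase: capitalize each underscore-separated piece.
    return "".join(part.capitalize() for part in base.split("_") if part)
```

Under this rule `triangular_solve_output` and a plain `triangular_solve` would map to the same `TriangularSolve` class, which then carries one `operator()` per surviving overload.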

Motivation

Replaces 500+ hand-written src/base/<op>.h headers with a single declarative pipeline driven from PyTorch's schema. Each commit is single-purpose and individually passes ruff check, ruff format --check, and clang-format (version 21).

The previous iteration (feat/torch-codegen-legacy, preserved on the remote) generated suffixed names that reviewers consistently flagged as bad public API (per inline comments on PRs #280, #283-#290, #509, #563-#589). It also hid semantically critical parameters and did not store scalars as members, which required hand-written corrections in every base PR. This refactor moves those corrections into the codegen itself, so future regenerations produce the reviewer-preferred shape directly. The 77 PRs that targeted non-canonical overload names have been closed; the remaining 333 keep + 103 promote PRs have their content regenerated to match the new codegen output.

Closes #

Type of Change

  • feat — new feature / new operator / new platform
    N/A: fix — this is a feature PR rather than a bug-fix-only PR.
    N/A: perf — no runtime hot-path performance change is intended.
    N/A: refactor — the primary change is generated PyTorch backend support.
    N/A: test — this PR adds tests, but it is not a test-only PR.
    N/A: docs — documentation-only PR does not apply.
  • build / ci — build system or CI configuration
    N/A: chore — tooling-only PR does not apply.
    N/A: Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer).

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

All supported platforms were validated with the direct profile. The generated PyTorch backend tests are included in the collected pytest set on every platform below.

| Platform | Built | pytest Result | Notes / Hardware |
| --- | --- | --- | --- |
| NVIDIA | Yes | 9147 passed, 7462 skipped, 81 warnings in 373.45s (0:06:13) | direct profile; 16609 items collected; PyTorch backend tests included |
| Iluvatar | Yes | 7371 passed, 7738 skipped, 81 warnings in 377.58s (0:06:17) | direct profile; 15109 items collected; native and PyTorch backend tests included |
| MetaX | Yes | 8639 passed, 6470 skipped, 81 warnings in 382.57s (0:06:22) | direct profile; 15109 items collected; PyTorch backend tests included |
| Cambricon | Yes | 5852 passed, 8891 skipped, 169 warnings in 996.13s (0:16:36) | direct profile; 14743 items collected; PyTorch backend tests included |
| Moore | Yes | 8422 passed, 6687 skipped, 99 warnings in 600.34s (0:10:00) | direct profile; 15109 items collected; PyTorch backend tests included |
| Ascend | Yes | 7344 passed, 7705 skipped, 98 warnings in 573.57s (0:09:33) | direct profile; 15049 items collected; pytest completed successfully; outer container returned 137 after the pytest summary |

Benchmark / Performance Impact

N/A. This PR adds a codegen pipeline, not a runtime hot-path change. Generated PyTorch backends call at::<op>_out(...) directly, so per-op performance matches a hand-written ATen-backed op.

Notes for Reviewers

  • The branch was force-pushed over the previous feat/torch-codegen integration branch. The previous content is preserved at feat/torch-codegen-legacy for reference.
  • The 513 open base PRs against feat/torch-codegen have been processed: 77 redundant overload PRs closed, 333 keep + 103 promote PRs scheduled to be force-pushed with regenerated content matching the new canonical naming and parameter shape.
  • Slot 8 is reserved for PyTorch backends; native and vendor implementations claim slots 0–7. The slot must be > 0 to avoid a partial-specialization-after-instantiation conflict with Operator<Op> at index 0.
  • Hand-written src/base/<op>.h continues to shadow generated/base/<op>.h (existence-based; no signature compatibility check). The four pre-existing hand-written bases that do not match the ATen-derived signature (add, linear, matmul, mul) are excluded from scripts/torch_ops.yaml and left to their existing hand-written infrastructure.
  • Optional ATen types (Tensor?, Scalar?, int?, float?) remain hidden for now — exposing them properly requires threading std::optional through to ATen, which is a separable refactor.
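A conceptual Python model of the slot convention and the graceful unknown-device path described above (names are hypothetical; the real mechanism is C++ template specialization of `Operator<Op>` plus `detail::ListContains` / `TryDeviceTypeFromString`):

```python
NUM_SLOTS = 9     # slots 0-7: native and vendor backends; slot 8: PyTorch
PYTORCH_SLOT = 8  # must be > 0, per the partial-specialization note above

def active_implementation_indices(slots):
    """Return the slot indices that actually provide an implementation."""
    return [i for i, impl in enumerate(slots) if impl is not None]

def try_device_type_from_string(name, known_devices):
    """Return the device if recognized, else `None` instead of aborting."""
    return name if name in known_devices else None
```

The point of the `try_*` shape is that a generated binding asked about a device an op does not implement gets an empty answer rather than a crash.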

Checklist

Every contributor must verify every item below before requesting review. Tick each box only after the check has actually been performed — do not tick speculatively. If an item truly does not apply, replace the checkbox with N/A and briefly explain why in an inline comment.

Title, Branch, and Commits

  • PR title follows Conventional Commits.
  • Branch name follows `<type>/xxx-yyyy-zzzz`: feat/torch-codegen.
  • Each commit message follows Conventional Commits.
  • Large PR with meaningful, well-formed, independently reviewable commits (6 commits, each one logical change).
  • No stray merge commits from master — branch is rebased cleanly on top of current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal for a codegen-introduction PR; non-essential workflow tooling (merge_base_branches.py from the legacy branch) was dropped.
  • No dead code, commented-out blocks, debug prints, or TODO without an owner.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes are intentional: the slot-8 dispatch path and the new infini::ops::<Pascal> classes are documented via the codegen's docstring.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious.
  • Every modified or added file ends with a single trailing newline.
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks.
  • All comments and error messages are in English.
  • Comments and error messages are complete sentences with terminal punctuation.

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • clang-format (version 21) is clean for every modified .h, .cc, .cuh file. Verified via git rebase master --exec 'clang-format --dry-run --Werror $(git diff HEAD~1 --name-only -- "*.h" "*.cc" "*.cuh")'.
    N/A: clang-tidy was not run; this repository's PR validation for this generated C++ surface is clang-format, full cross-platform build, and full cross-platform pytest.
  • Operator parameter order: inputs first, attributes between, outputs last (matches ATen _out form).
  • No exceptions — error paths use assert. Generated code uses ATen which itself uses TORCH_CHECK; that is consistent with the existing torch backend pattern.
  • Error/warning wording follows LLVM Coding Standards.
  • N/A: no new kernel files added — codegen only emits ATen wrappers.
  • Constructor initializer list order matches member declaration order (verified by inspecting generated headers).
  • One blank line between classes/functions; one between class members; one before/after namespace contents.
  • New operators added via src/base/<op>.h (auto-generated under generated/base/) inheriting Operator<Op>; PyTorch backend specializes at slot 8.
  • No raw new/delete.

Python Specific (if Python files changed)

  • ruff check is clean for the entire repo.
  • ruff format --check is clean for the entire repo. Verified per-commit via git rebase master --exec 'ruff format --check . && ruff check .'.
  • Comments are complete English sentences with backticked code references.
  • pytest.skip messages are lowercase without terminal period (framework convention).
  • No blank line between function signature and body when there is no docstring or comment.
  • Blank lines around control-flow statements.
  • Blank line before return when not directly following a control-flow statement.
  • Docstrings follow PEP 257.
  • Type hints are present on new dataclasses (Param, Op) and on every public function.

Testing

  • pytest was run locally on every supported platform — see the platform table above.
  • Reasons for any platform that could not be tested — N/A: all supported platforms were tested; Ascend's outer container returned 137, but only after pytest had already printed a passing summary.
  • New functionality has matching tests under tests/test_torch_ops.py.
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (("dtype", "rtol", "atol")); independent parameters use separate decorators.
  • Default dtype / device parameterization is relied on; op_meta and shape are added with explicit parametrize.
  • N/A: no flaky tests under parallelism added.
  • N/A: this is a feature PR, not a bug fix — no regression test required.
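The parametrization convention above can be illustrated with a sketch (op name, tolerances, and shapes are illustrative, not the actual contents of `tests/test_torch_ops.py`):

```python
import pytest

# Dependent parameters (a dtype and its matching tolerances) share one
# decorator, so only valid (dtype, rtol, atol) combinations are generated.
@pytest.mark.parametrize(("dtype", "rtol", "atol"), [
    ("float32", 1e-5, 1e-8),
    ("float16", 1e-3, 1e-4),
])
# Independent parameters get their own decorator; pytest takes the
# cross product, yielding 2 dtypes x 3 shapes = 6 test cases here.
@pytest.mark.parametrize("shape", [(8,), (4, 4), (2, 3, 4)])
def test_generated_op(dtype, rtol, atol, shape):
    assert len(shape) >= 1  # placeholder body; the real test runs the op
```

Putting dependent parameters in one decorator avoids nonsensical combinations such as `float16` with `float32`-grade tolerances.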

Build, CI, and Tooling

  • Builds cleanly from a fresh directory with pip install .[dev] on every supported platform — see the platform table above.
  • compile_commands.json regenerates (no change to pyproject.toml's CMAKE_EXPORT_COMPILE_COMMANDS=ON).
  • N/A: no new backends / devices added.
  • N/A: CUDA-like GPU mutual exclusion not changed.
  • Both CI workflows (clang-format.yml, ruff.yml) are expected to be green — verified locally per-commit.
  • New build dependency pyyaml added to pyproject.toml's [build-system].requires.

Documentation

  • N/A: README.md, CONTRIBUTING.md, and developer workflow are unchanged for end users; the codegen docstring documents internal behaviour for maintainers.
  • Codegen docstring + per-section comments explain the codegen pipeline; the PyTorch slot-8 convention is documented in the _PYTORCH_SLOT comment.
  • N/A: no user-visible breaking change.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • No third-party code added (the codegen reads aten/native_functions.yaml from PyTorch's GitHub but does not vendor it).
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks introduced.

@voltjia voltjia requested a review from a team May 9, 2026 07:58
@voltjia voltjia force-pushed the feat/torch-codegen branch from 9cb7b73 to be71261 Compare May 15, 2026 14:51