# feat: YAML-driven torch op codegen with canonical naming and exposed semantic params (#595)

**Status:** Open — voltjia wants to merge 6 commits.
## Summary
- `scripts/generate_torch_ops.py` (~920 lines) — a YAML-driven codegen that consumes PyTorch's `aten/native_functions.yaml` and emits an InfiniOps base class plus a slot-8 PyTorch backend per op listed in `scripts/torch_ops.yaml` (~459 ops, generating 507 overloads across 437 canonical classes).
- Build integration (`src/CMakeLists.txt`) under `WITH_TORCH=ON`: invoke the codegen at configure time, glob `generated/torch/*.cc`, add `generated/` to the public include paths, and install the per-op metadata JSON alongside the bindings.
- Wrapper generator updates (`scripts/generate_wrappers.py`): scan `generated/base/` and `generated/torch/`, fix pybind11 overload ordering (specific → permissive), preserve `std::vector<int64_t>` parameters that libclang misreports as `int`, and route `active_implementation_indices` through a graceful unknown-device path.
- Runtime helpers (`detail::ListContains` in `src/operator.h`, `TryDeviceTypeFromString` in `src/pybind11_utils.h`) so generated bindings handle devices an op does not implement without aborting.
- `tests/test_torch_ops.py`: reads `generated/torch_ops_metadata.json` and exercises every generated op across three shapes and three dtypes; `tests/conftest.py` is widened to handle non-floating outputs and `equal_nan`.
- The `Sigmoid` helper in `src/native/cuda/ops/swiglu/kernel.cuh` is moved into `detail::` so it does not collide with the auto-generated `infini::ops::Sigmoid` operator class.
- `pyyaml` is added to `[build-system].requires` so CMake can run the codegen during `pip install`.

Codegen design choices driven by review feedback collected across all 513 base PRs against `feat/torch-codegen`:

- Overload suffixes (`_grad_input`, `_out`, `tensor`, `_n`, `_scalar`, `_values`, `_x`, `_l`, `_q`, `_u`, `_output`) no longer leak into InfiniOps class names. Multiple ATen overloads of the same base op share a single class, with overloaded `operator()` methods.
- Semantic parameters such as `bool upper`, `bool transpose`, `bool unitriangular` (triangular_solve), `int diagonal` (triu), `str ord` (linalg_matrix_norm), and `int n` on the chebyshev/hermite polynomial families are no longer hidden just because they have an ATen default — they are now visible in the generated `operator()` and forwarded to ATen.

## Motivation
Replaces 500+ hand-written `src/base/<op>.h` headers with a single declarative pipeline driven from PyTorch's schema. Each commit is single-purpose and individually passes `ruff check`, `ruff format --check`, and `clang-format` (version 21).

The previous iteration (`feat/torch-codegen-legacy`, preserved on the remote) generated suffixed names that reviewers consistently flagged as bad public API (per inline comments on PRs #280, #283-#290, #509, #563-#589). It also hid semantically critical parameters and did not store scalars as members, requiring hand-written corrections in every base PR. This refactor moves those corrections into the codegen itself, so future regenerations produce the reviewer-preferred shape directly. The 77 PRs that were for non-canonical overload names have been closed; the remaining 333 keep + 103 promote PRs have their content regenerated to match the new codegen output.

Closes #
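The canonical-naming rule described above can be sketched roughly as follows. This is a minimal illustration, not the actual logic in `scripts/generate_torch_ops.py`; the helper name `canonical_class_name`, the exact suffix list, and the matching order are assumptions:

```python
# Hypothetical sketch: strip ATen overload suffixes before deriving the
# PascalCase InfiniOps class name, so suffixed overloads of the same base
# op collapse into one canonical class.
_OVERLOAD_SUFFIXES = (
    # Longest suffixes first so "_output" is tried before "_out".
    "_grad_input", "_output", "_values", "_scalar", "_out", "_n",
    "_x", "_l", "_q", "_u",
)

def canonical_class_name(aten_name: str) -> str:
    base = aten_name
    for suffix in _OVERLOAD_SUFFIXES:
        if base.endswith(suffix) and base != suffix:
            base = base[: -len(suffix)]
            break
    # PascalCase, e.g. "triangular_solve" -> "TriangularSolve".
    return "".join(part.capitalize() for part in base.split("_") if part)
```

With a rule of this shape, `sigmoid_out` and `sigmoid` both map to a single `Sigmoid` class, matching the reviewer-preferred naming.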
## Type of Change

- [x] `feat` — new feature / new operator / new platform
- N/A: `fix` — this is a feature PR rather than a bug-fix-only PR.
- N/A: `perf` — no runtime hot-path performance change is intended.
- N/A: `refactor` — the primary change is generated PyTorch backend support.
- N/A: `test` — this PR adds tests, but it is not a test-only PR.
- N/A: `docs` — documentation-only PR does not apply.
- [x] `build/ci` — build system or CI configuration
- N/A: `chore` — tooling-only PR does not apply.
- N/A: Breaking change (requires a `!` in the Conventional Commits prefix or a `BREAKING CHANGE:` footer).

## Platforms Affected
- CPU (`WITH_CPU`)
- NVIDIA (`WITH_NVIDIA`)
- Iluvatar (`WITH_ILUVATAR`)
- Metax (`WITH_METAX`)
- Cambricon (`WITH_CAMBRICON`)
- Moore (`WITH_MOORE`)
- Ascend (`WITH_ASCEND`)
- PyTorch backend (`WITH_TORCH`)

## Test Results on Supported Platforms
All supported platforms were validated with the direct profile. The generated PyTorch backend tests are included in the collected pytest set on every platform below.
| `pytest` result | Notes |
| --- | --- |
| 9147 passed, 7462 skipped, 81 warnings in 373.45s (0:06:13) | 16609 items collected; PyTorch backend tests included |
| 7371 passed, 7738 skipped, 81 warnings in 377.58s (0:06:17) | 15109 items collected; native and PyTorch backend tests included |
| 8639 passed, 6470 skipped, 81 warnings in 382.57s (0:06:22) | 15109 items collected; PyTorch backend tests included |
| 5852 passed, 8891 skipped, 169 warnings in 996.13s (0:16:36) | 14743 items collected; PyTorch backend tests included |
| 8422 passed, 6687 skipped, 99 warnings in 600.34s (0:10:00) | 15109 items collected; PyTorch backend tests included |
| 7344 passed, 7705 skipped, 98 warnings in 573.57s (0:09:33) | 15049 items collected; pytest completed successfully; outer container returned 137 after the pytest summary |

Full `pytest` output (optional)
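The metadata-driven test harness summarized above can be sketched like this. The file layout, metadata schema, and test name here are assumptions for illustration; the real harness is `tests/test_torch_ops.py` reading `generated/torch_ops_metadata.json`:

```python
import json
from pathlib import Path

import pytest

# Hypothetical sketch of a metadata-driven test: every generated op is
# exercised across three shapes and three dtypes, as the PR describes.
METADATA = Path("generated/torch_ops_metadata.json")

def load_op_metas() -> list:
    """Load per-op metadata emitted by the codegen; empty if absent."""
    if not METADATA.exists():
        return []
    return json.loads(METADATA.read_text())

@pytest.mark.parametrize("op_meta", load_op_metas())
@pytest.mark.parametrize("shape", [(4,), (4, 8), (2, 3, 5)])
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_generated_op(op_meta, shape, dtype):
    # Body elided: build inputs from shape/dtype, call the generated op,
    # and compare against the ATen reference.
    ...
```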
## Benchmark / Performance Impact

N/A. This PR adds a codegen pipeline, not a runtime hot-path change. Generated PyTorch backends call `at::<op>_out(...)` directly, so per-op performance matches a hand-written ATen-backed op.

## Notes for Reviewers
- Targets the `feat/torch-codegen` integration branch. The previous content is preserved at `feat/torch-codegen-legacy` for reference.
- All base PRs against `feat/torch-codegen` have been processed: 77 redundant overload PRs closed, 333 keep + 103 promote PRs scheduled to be force-pushed with regenerated content matching the new canonical naming and parameter shape.
- The PyTorch backend slot index is `> 0` to avoid a partial-specialization-after-instantiation conflict with `Operator<Op>` at index 0.
- A hand-written `src/base/<op>.h` continues to shadow `generated/base/<op>.h` (existence-based; no signature compatibility check). The four pre-existing hand-written bases that do not match the ATen-derived signature (`add`, `linear`, `matmul`, `mul`) are excluded from `scripts/torch_ops.yaml` and left to their existing hand-written infrastructure.
- Optional parameters (`Tensor?`, `Scalar?`, `int?`, `float?`) remain hidden for now — exposing them properly requires threading `std::optional` through to ATen, which is a separable refactor.

## Checklist
### Title, Branch, and Commits
- Branch name follows `<type>/xxx-yyyy-zzzz` — `feat/torch-codegen`.
- Branch is rebased cleanly on top of current `master`.
- No `fixup!`/`squash!`/wip commits remain.

### Scope and Design
- `merge_base_branches.py` (from the legacy branch) was dropped.
- No `TODO` without an owner.
- The `infini::ops::<Pascal>` classes are documented via the codegen's docstring.

### General Code Hygiene (applies to all languages)
### C++ Specific (if C++ files changed)
- `clang-format` (version 21) is clean for every modified `.h`, `.cc`, `.cuh` file. Verified via `git rebase master --exec 'clang-format --dry-run --Werror $(git diff HEAD~1 --name-only -- "*.h" "*.cc" "*.cuh")'`.
- N/A: `clang-tidy` was not run; this repository's PR validation for this generated C++ surface is `clang-format`, full cross-platform build, and full cross-platform pytest.
- Generated backends call ATen's out-variant (the `_out` form).
- No bare `assert`. Generated code uses ATen, which itself uses `TORCH_CHECK`; that is consistent with the existing torch backend pattern.
- Each op gets a base class `src/base/<op>.h` (auto-generated under `generated/base/`) inheriting `Operator<Op>`; the PyTorch backend specializes at slot 8.
- No raw `new`/`delete`.

### Python Specific (if Python files changed)
- `ruff check` is clean for the entire repo.
- `ruff format --check` is clean for the entire repo. Verified per-commit via `git rebase master --exec 'ruff format --check . && ruff check .'`.
- `pytest.skip` messages are lowercase without terminal period (framework convention).
- Blank line before `return` when not directly following a control-flow statement.
- Type annotations on public classes (`Param`, `Op`) and on every public function.

### Testing
- `pytest` was run locally on every supported platform — see the platform table above.
- New tests live in `tests/test_torch_ops.py`.
- `pytest.mark.parametrize` is used correctly: dependent parameters share one decorator (`("dtype", "rtol", "atol")`); independent parameters use separate decorators.
- Existing `dtype`/`device` parameterization is relied on; `op_meta` and `shape` are added with explicit `parametrize`.

### Build, CI, and Tooling
- Verified `pip install .[dev]` on every supported platform — see the platform table above.
- `compile_commands.json` regenerates (no change to `pyproject.toml`'s `CMAKE_EXPORT_COMPILE_COMMANDS=ON`).
- CI workflows (`clang-format.yml`, `ruff.yml`) are expected to be green — verified locally per-commit.
- `pyyaml` added to `pyproject.toml`'s `[build-system].requires`.

### Documentation
- `README.md`, `CONTRIBUTING.md`, and the developer workflow are unchanged for end users; the codegen docstring documents internal behaviour for maintainers.
- The backend slot choice is documented in the `_PYTORCH_SLOT` comment.

### Security and Safety
- The codegen downloads `aten/native_functions.yaml` from PyTorch's GitHub but does not vendor it.
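The graceful unknown-device path mentioned in the summary (`detail::ListContains` in `src/operator.h`, `TryDeviceTypeFromString` in `src/pybind11_utils.h`) behaves roughly like the Python model below. This is a sketch of the C++ helpers' observable behavior only; the device table, slot numbers other than the PyTorch slot 8, and function names are illustrative assumptions:

```python
from typing import Optional

# Illustrative device-name table; the real mapping lives in C++
# (TryDeviceTypeFromString). Slot 8 is the PyTorch backend per the PR.
_DEVICE_SLOTS = {"cpu": 0, "nvidia": 1, "torch": 8}

def try_device_slot(name: str) -> Optional[int]:
    """Return the backend slot for a device string, or None if unknown
    (mirrors TryDeviceTypeFromString returning no value)."""
    return _DEVICE_SLOTS.get(name.lower())

def is_implemented(device: str, active_implementation_indices: list) -> bool:
    """Report False for unknown devices or missing implementations
    instead of aborting (mirrors the detail::ListContains check)."""
    slot = try_device_slot(device)
    return slot is not None and slot in active_implementation_indices
```

The point of the design is that a binding asked about a device an op does not implement answers "not implemented" rather than terminating the process.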