From 5d6e1d86f9bbf2614d638d75d7cc7543cb7d5e8b Mon Sep 17 00:00:00 2001
From: Jiacheng Huang
Date: Sat, 16 May 2026 09:40:53 +0800
Subject: [PATCH 1/4] build: add PyYAML build dependency

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 959699f90..a18e0e1af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["scikit-build-core", "pybind11", "libclang"]
+requires = ["scikit-build-core", "pybind11", "libclang", "pyyaml"]
 build-backend = "scikit_build_core.build"
 
 [project]

From 21d376f688c9bd63eaf5bcf3fb6692383c0a7e32 Mon Sep 17 00:00:00 2001
From: Jiacheng Huang
Date: Sat, 16 May 2026 09:41:17 +0800
Subject: [PATCH 2/4] feat(scripts): add YAML-driven torch op codegen

---
 scripts/generate_torch_ops.py | 1102 +++++++++++++++++++++++++++++++++
 scripts/torch_ops.yaml        |  540 ++++++++++++++++
 2 files changed, 1642 insertions(+)
 create mode 100644 scripts/generate_torch_ops.py
 create mode 100644 scripts/torch_ops.yaml

diff --git a/scripts/generate_torch_ops.py b/scripts/generate_torch_ops.py
new file mode 100644
index 000000000..bcbe64f0f
--- /dev/null
+++ b/scripts/generate_torch_ops.py
@@ -0,0 +1,1102 @@
"""Generate InfiniOps PyTorch wrappers from ATen `native_functions.yaml`.

For each op listed in `scripts/torch_ops.yaml`, this script finds the
`<op>.out` variant in PyTorch's locally installed `native_functions.yaml`,
parses its schema, and emits:
  - `generated/base/<op>.h` — the InfiniOps base class
    `class <Op> : public Operator<<Op>>`, with constructors and pure-virtual
    `operator()` overloads mirroring the selected ATen schemas.
  - `generated/torch/<op>/<op>.h` and `.cc` — the PyTorch backend
    `Operator<<Op>, kDev, 8>` that calls `at::<op>_out(out, ...)`.
  - `generated/torch_ops_metadata.json` — per-overload metadata (public
    name, ATen name, visible params) for every successfully-generated op,
    consumed by the parametrized test suite.
+ +Slot 8 is the reserved convention for PyTorch backends; slots 0-7 are +left for native or vendor implementations. (The slot must also be > 0 +to side-step a partial-specialization-after-instantiation conflict with +the primary template `Operator<>` instantiated at index 0.) + +The generated files are not committed; CMake regenerates them at configure +time when `WITH_TORCH=ON`. +""" + +import argparse +import dataclasses +import importlib.util +import json +import os +import pathlib +import re +import shutil +import subprocess +import sys + +import yaml + +_SCRIPTS_DIR = pathlib.Path(__file__).resolve().parent +_REPO_ROOT = _SCRIPTS_DIR.parent +_OPS_YAML_PATH = _SCRIPTS_DIR / "torch_ops.yaml" +_BASE_DIR = _REPO_ROOT / "src" / "base" +_GENERATED_DIR = _REPO_ROOT / "generated" +_GENERATED_BASE_DIR = _GENERATED_DIR / "base" +_GENERATED_TORCH_DIR = _GENERATED_DIR / "torch" +_METADATA_PATH = _GENERATED_DIR / "torch_ops_metadata.json" + +# Reserved slot for PyTorch backends. Native and vendor implementations +# claim slots 0-7; PyTorch wrappers always live at 8. +_PYTORCH_SLOT = 8 + +# ATen uses symbolic names for some `int`/`float` defaults (e.g. +# `reduction=Mean`). Map them to C++ identifiers usable in a call. +_ENUM_DEFAULTS = { + "Mean": "at::Reduction::Mean", + "Sum": "at::Reduction::Sum", + "Contiguous": "at::MemoryFormat::Contiguous", +} + +# Default PyTorch schema label used only in diagnostics when CMake does not +# provide the locally installed torch version. Codegen reads the actual schema +# from installed `torchgen/packaged/ATen/native/native_functions.yaml`. +_DEFAULT_PYTORCH_VERSION = "v2.4.0" + +# Order matches the device list in existing hand-written torch backends +# (see `src/torch/add/add.cc`). +_DEVICE_TYPES = ( + "kCpu", + "kNvidia", + "kCambricon", + "kAscend", + "kMetax", + "kMoore", + "kIluvatar", + "kKunlun", + "kHygon", + "kQy", +) + +# YAML scalar-type tokens → C++ types. Reference types (e.g. 
`const Scalar&`) +# are not used so the generated signatures match the existing hand-written +# ones, which pass by value to keep pybind11 binding generation simple. +_SCALAR_TYPE_MAP = { + # `at::Scalar` is implicitly constructible from `double`, so we expose + # scalars as `double` in the base class to keep it torch-independent. + "Scalar": "double", + "int": "int64_t", + "bool": "bool", + "float": "double", + # `SymInt` / `SymInt[]` exist for `torch.compile` internals; at runtime + # they're just `int64`/IntArrayRef. + "SymInt": "int64_t", + # `str` for required string params (e.g. `index_reduce.reduce`). + # `std::string` marshals through pybind11 cleanly and converts + # implicitly to ATen's `c10::string_view`. + "str": "std::string", +} + +# `Dimname` overloads (named-tensor dim) are skipped — passing them +# from Python to ATen requires a wrapper conversion through +# `at::Dimname::fromSymbol(...)` that doesn't fit the cleanly-rendered +# 1:1 arg model, and named tensors remain experimental in PyTorch. +# The int-dim overload is always emitted alongside, so we lose nothing +# user-visible. + +# Optional ATen types we hide from the user-facing API and pass as a +# typed empty optional at the call site. Covers the common "full +# default" case for most reductions and activations. We use a typed +# `c10::optional{}` rather than bare `at::nullopt` so the compiler +# can disambiguate ops with multiple `_out` overloads (e.g. `clamp_out` +# accepts both `optional` and `optional` for `min`/`max`). 
+_NULLOPT_BY_TYPE = { + "Scalar?": "c10::optional{}", + "int?": "c10::optional{}", + "bool?": "c10::optional{}", + "float?": "c10::optional{}", + "str?": "c10::optional{}", + "ScalarType?": "c10::optional{}", + "MemoryFormat?": "c10::optional{}", + "Layout?": "c10::optional{}", + "Device?": "c10::optional{}", + "Generator?": "c10::optional{}", + "Tensor?": "c10::optional{}", + "Tensor?[]": "c10::List>{}", + "int[]?": "c10::optional{}", + "int[1]?": "c10::optional{}", + "int[2]?": "c10::optional{}", + "int[3]?": "c10::optional{}", + "SymInt?": "c10::optional{}", + "SymInt[]?": "c10::optional{}", + "SymInt[1]?": "c10::optional{}", + "SymInt[2]?": "c10::optional{}", + "SymInt[3]?": "c10::optional{}", + "float[]?": "c10::optional>{}", +} +_HARDCODE_NULLOPT_TYPES = frozenset(_NULLOPT_BY_TYPE) + + +@dataclasses.dataclass +class Param: + name: str + aten_type: str + default: str | None + keyword_only: bool + + @property + def is_tensor(self) -> bool: + # Real tensors only. `Tensor?` is optional and falls through to + # the hidden-param path (substituted with `at::nullopt`). + + return self.aten_type == "Tensor" or self.aten_type.startswith("Tensor(") + + @property + def is_mutable_tensor(self) -> bool: + # Mutable tensors carry `!` in their alias annotation, e.g. + # `Tensor(a!)`. + + return self.is_tensor and "!" in self.aten_type + + @property + def is_out(self) -> bool: + # In ATen `_out` schemas, output tensors are keyword-only mutable tensor + # params. Some mutable tensors are real inputs (`running_mean` / + # `running_var` in `_batch_norm_with_update`) and must stay in schema + # order, so mutability alone is not enough. 
+ + return self.is_mutable_tensor and self.keyword_only + + @property + def is_hardcoded_nullopt(self) -> bool: + """If `True`, the param is omitted from the user-facing API and + passed as `at::nullopt` to ATen.""" + + return self.aten_type in _HARDCODE_NULLOPT_TYPES + + @property + def is_hidden(self) -> bool: + """True if the param is omitted from the user-facing API. + + Default-valued non-optional params (\\`bool\\`, \\`int\\`, \\`float\\`, + \\`str\\`, \\`int[N]\\`, …) used to be hidden as a convenience, but + reviewers consistently flagged the resulting omissions — + \\`bool upper/transpose/unitriangular\\` on \\`triangular_solve\\`, + \\`int diagonal\\` on \\`triu\\`, \\`str ord\\` on \\`linalg_matrix_norm\\`, + \\`int n\\` on the special chebyshev family, etc. — as missing + semantic controls. They are now exposed and forwarded to ATen. + + Optional ATen types (\\`Tensor?\\`, \\`Scalar?\\`, \\`int?\\`, …) remain + hidden for now — exposing them would require teaching the torch + source to thread \\`std::optional\\` through to ATen, which is a + separate refactor. The same goes for ATen-internal types like + \\`Generator?\\`/\\`Layout?\\` that have no InfiniOps analogue. + """ + + return self.is_hardcoded_nullopt + + def hidden_value(self) -> str: + """C++ literal substituted for a hidden param in the ATen call.""" + + if self.is_hardcoded_nullopt: + return _NULLOPT_BY_TYPE[self.aten_type] + + if self.default == "True": + return "true" + + if self.default == "False": + return "false" + + if self.aten_type.startswith(("int[", "SymInt[")) and self.default is not None: + # `int[N]=[a, b, c]` → `{a, b, c}`; `int[N]=0` (scalar default + # for list type) → `{0, 0, ...}` replicated to size N. 
+ if self.default.startswith("["): + return "{" + self.default[1:-1] + "}" + + size_match = re.search(r"\[(\d+)\]", self.aten_type) + n = int(size_match.group(1)) if size_match else 1 + + return "{" + ", ".join([self.default] * n) + "}" + + if self.aten_type == "str" and self.default is not None: + # YAML uses single-quoted strings (e.g. `'none'`); C++ char + # literals also use single quotes, so swap to doubles. + + return '"' + self.default.strip("'\"") + '"' + + if self.aten_type in {"int", "float", "SymInt"} and self.default is not None: + # Translate known ATen enum defaults to their C++ identifiers. + + return _ENUM_DEFAULTS.get(self.default, self.default) + + raise AssertionError( + f"param {self.name!r} of type {self.aten_type!r} with default " + f"{self.default!r} is not hidden" + ) + + @property + def cpp_type(self) -> str: + if self.is_tensor: + # `Tensor[]` / `Tensor(a!)[]` would need `std::vector` and a + # different ATen call shape — not yet supported, so reject so the + # whole overload gets skipped instead of emitting code that calls + # `at::_out(at::Tensor, ...)` against an `at::TensorList` + # signature. + if self.aten_type.endswith("[]"): + raise NotImplementedError( + f"`Tensor[]` param {self.name!r} not supported yet" + ) + + return "Tensor" + + if self.is_hidden: + # Not exposed — the ATen call substitutes a hardcoded value + # so the `cpp_type` is irrelevant. + + return "void" + + bare = self.aten_type.rstrip("?") + # Required `int[N]` / `SymInt[N]` (no default) — pybind11 accepts + # a Python list of ints into `std::vector`, which ATen + # promotes to `IntArrayRef` implicitly. 
+ if bare.startswith(("int[", "SymInt[")) or bare in {"int[]", "SymInt[]"}: + return "std::vector" + + try: + return _SCALAR_TYPE_MAP[bare] + except KeyError as exc: + raise NotImplementedError( + f"unsupported ATen type {self.aten_type!r} for param {self.name!r}" + ) from exc + + +@dataclasses.dataclass +class Op: + aten_name: str + overload: str + params: list[Param] + + @property + def pascal_name(self) -> str: + return _snake_to_pascal(self.infini_name) + + @property + def infini_name(self) -> str: + """InfiniOps public op name. + + ATen disambiguates `_out` overloads with suffixes like `Tensor_Tensor_out`, + `out_x`, `forward_output`, `grad_input`, but reviewers consistently + flag those suffixes as bad public-API naming when they leak into + InfiniOps class names. Different ATen overloads of the same base op + become overloaded `operator()` methods on a single class instead. When + two overloads collapse to the same visible C++ signature after hidden + defaults, `_dedupe_visible_overloads` keeps only one. + + ATen-internal leading underscores and in-place trailing underscores are + also normalized so generated class names do not collide with existing + public ops (`_softmax` → `AtenSoftmax`, `add_` → `AddInplace`). + """ + + return _public_op_name(self.aten_name) + + @property + def is_inplace(self) -> bool: + """True for ATen's single-underscore in-place variants.""" + + return _is_inplace_aten_name(self.aten_name) + + @property + def tensor_params(self) -> list[Param]: + return [p for p in self.params if p.is_tensor] + + @property + def out_params(self) -> list[Param]: + """Mutable tensor outputs. Most ops have one (`Tensor(a!) out`); + multi-output ops like `frexp` or `sort` have several + (`Tensor(a!) values`, `Tensor(b!) indices`).""" + + if self.is_inplace: + return [self.params[0]] + + return [p for p in self.params if p.is_out] + + @property + def out_param(self) -> Param: + """Single-output convenience. 
Asserts there's exactly one.""" + outs = self.out_params + assert len(outs) == 1, f"op {self.aten_name!r} has {len(outs)} out tensors" + + return outs[0] + + @property + def visible_params(self) -> list[Param]: + """Params the wrapper exposes to the user; hidden ones (hardcoded + optional nullopt, default-`False`/`True` bools) are filtered.""" + + return [p for p in self.params if not p.is_hidden] + + @property + def is_testable(self) -> bool: + """Cheap structural check: at least one out tensor, and the first + constructor parameter is a tensor. The latter is needed because + `Operator::Make(Tensor tensor, Args... args)` dispatches on + `tensor.device()`, so an op like `pow.Scalar_out(Scalar self, + Tensor exponent, *, Tensor(a!) out)` cannot be wired up without + a separate dispatch path. Generators like `arange` / `linspace` + also fall under this rule (no input tensors at all).""" + + if not self.out_params: + return False + + if self.is_inplace: + return self.params[0].is_mutable_tensor + + # `params` includes out tensors at the end; check the first + # non-out param. If there are no non-out params (`empty.out`, + # `arange.out`), this op also fails the dispatch precondition. + non_out = [p for p in self.params if not p.is_out] + + if not non_out: + return False + + return non_out[0].is_tensor + + +_FUNC_RE = re.compile( + r"^(?P[a-zA-Z_][a-zA-Z0-9_]*)" + r"(?:\.(?P\w+))?" 
+ r"\((?P.*)\)\s*->\s*.+$" +) + +_ARG_RE = re.compile( + r"^(?P\S+(?:\([^)]*\))?\??)" # type with optional alias and `?` + r"\s+(?P\w+)" + r"(?:\s*=\s*(?P.+))?$" +) + + +def _parse_func(func_str: str) -> Op: + m = _FUNC_RE.match(func_str) + + if not m: + raise ValueError(f"could not parse func: {func_str!r}") + + return Op( + aten_name=m.group("name"), + overload=m.group("overload") or "", + params=_parse_args(m.group("args")), + ) + + +def _parse_args(args_str: str) -> list[Param]: + params: list[Param] = [] + keyword_only = False + + for token in _split_args(args_str): + if token == "*": + keyword_only = True + continue + + params.append(_parse_one_arg(token, keyword_only)) + + return params + + +def _split_args(args_str: str) -> list[str]: + """Split on top-level commas, respecting `(...)` and `[...]`.""" + parts: list[str] = [] + depth = 0 + current: list[str] = [] + + for ch in args_str: + if ch in "([": + depth += 1 + current.append(ch) + elif ch in ")]": + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + piece = "".join(current).strip() + + if piece: + parts.append(piece) + + current = [] + else: + current.append(ch) + + tail = "".join(current).strip() + + if tail: + parts.append(tail) + + return parts + + +def _parse_one_arg(token: str, keyword_only: bool) -> Param: + m = _ARG_RE.match(token) + + if not m: + raise ValueError(f"could not parse arg: {token!r}") + + name = m.group("name") + # ATen names the first tensor parameter `self` (matching the + # method-style \`tensor.abs()\` convention). InfiniOps uses + # \`input\` for the primary tensor input across all hand-written + # bases (\`Add\`, \`Gemm\`, …) per \`CONTRIBUTING.md\` §C++. + # Rename at parse time so the generated headers match. 
+ if name == "self": + name = "input" + + return Param( + name=name, + aten_type=m.group("type"), + default=m.group("default"), + keyword_only=keyword_only, + ) + + +def _snake_to_pascal(s: str) -> str: + return "".join(p.capitalize() for p in s.split("_")) + + +def _is_inplace_aten_name(name: str) -> bool: + """Return whether `name` is an ATen in-place operator name.""" + + return name.endswith("_") and not name.endswith("__") + + +def _public_op_name(aten_name: str) -> str: + """Map ATen-only spelling to stable InfiniOps public names.""" + + public_name = aten_name + prefix = "" + + if public_name.startswith("_"): + prefix = "aten_" + public_name = public_name.lstrip("_") + + if _is_inplace_aten_name(public_name): + public_name = public_name[:-1] + "_inplace" + + return prefix + public_name + + +def _base_path(op_name: str) -> pathlib.Path: + return _BASE_DIR / f"{op_name}.h" + + +def _load_aten_yaml(version: str) -> str: + """Return the `native_functions.yaml` bundled with installed `torchgen`. + + `WITH_TORCH=ON` already requires a local PyTorch installation. PyTorch + wheels ship the matching ATen schema under `torchgen/packaged`, including + vendor forks, so codegen should not depend on fetching PyTorch sources from + GitHub during CI builds. + """ + + packaged = _load_packaged_aten_yaml() + + if packaged is None: + raise RuntimeError( + "could not find installed `torchgen` packaged " + f"`native_functions.yaml` for PyTorch schema {version!r}" + ) + + print( + "using installed `torchgen` packaged `native_functions.yaml` " + f"for PyTorch schema {version}.", + file=sys.stderr, + ) + + return packaged + + +def _load_packaged_aten_yaml() -> str | None: + """Return the `native_functions.yaml` bundled with installed `torchgen`. + + PyTorch wheels install `torchgen/packaged/ATen/native/native_functions.yaml`; + using it lets offline CI images generate wrappers against the exact schema + shipped with their installed torch fork. 
+ """ + + spec = importlib.util.find_spec("torchgen") + + if spec is None or spec.submodule_search_locations is None: + return None + + for root in spec.submodule_search_locations: + path = ( + pathlib.Path(root) + / "packaged" + / "ATen" + / "native" + / "native_functions.yaml" + ) + + if path.is_file(): + return path.read_text() + + return None + + +def _find_out_entries(entries: list[dict], op_name: str) -> list[dict]: + """Return all out-variant entries for `op_name`, with the bare + `.out(` form first and overload-suffixed variants + (e.g. `pow.Tensor_Tensor_out(`, `kthvalue.values(`) after. An + entry counts as an out-variant when it (a) is named + `.out`, (b) ends in `_out`, or (c) carries a + `Tensor(!)` mutability annotation — that last case covers + multi-output ops named after their output tensors + (`kthvalue.values`, `mode.values`, …).""" + bare_prefix = f"{op_name}.out(" + op_overload = re.compile(rf"^{re.escape(op_name)}\.\w+\(") + mut_tensor = re.compile(r"Tensor\([a-z]!\)") + bare: list[dict] = [] + others: list[dict] = [] + + for entry in entries: + func = entry.get("func", "") + + if func.startswith(bare_prefix): + bare.append(entry) + elif op_overload.match(func) and ( + func.split("(", 1)[0].endswith("_out") or mut_tensor.search(func) + ): + others.append(entry) + + return bare + others + + +def _format_signature(op: Op, *, include_defaults: bool = False) -> str: + parts = [] + + for param in op.visible_params: + prefix = "" if param.is_mutable_tensor else "const " + text = f"{prefix}{param.cpp_type} {param.name}" + + if include_defaults and param.default is not None: + text += f" = {_translate_default(param)}" + + parts.append(text) + + return ", ".join(parts) + + +def _visible_signature_key(op: Op) -> tuple[str, ...]: + """C++ overload identity for the user-facing API. + + Parameter names and top-level `const` do not distinguish C++ overloads, so + only the exposed C++ type sequence participates in duplicate detection. 
+ """ + + return tuple(param.cpp_type for param in op.visible_params) + + +def _canonical_overload_score(index: int, op: Op) -> tuple[bool, int, int, str, int]: + """Sort key for duplicate visible signatures. + + Prefer the canonical unsuffixed InfiniOps name, then the schema that hides + fewer ATen-only defaults, then the shorter deterministic name. + """ + + return ( + op.infini_name != op.aten_name, + sum(param.is_hidden for param in op.params), + len(op.infini_name), + op.infini_name, + index, + ) + + +def _dedupe_visible_overloads(ops: list[Op]) -> tuple[list[Op], list[tuple[Op, Op]]]: + """Drop overloads that collapse to the same visible C++ signature. + + Returns the selected overloads in the original schema order plus a list of + `(skipped, kept)` duplicate pairs for diagnostics. + """ + winners: dict[tuple[str, ...], tuple[int, Op]] = {} + duplicates: list[tuple[Op, tuple[str, ...]]] = [] + + for index, op in enumerate(ops): + key = _visible_signature_key(op) + current = winners.get(key) + + if current is None: + winners[key] = (index, op) + continue + + current_index, current_op = current + + if _canonical_overload_score(index, op) < _canonical_overload_score( + current_index, current_op + ): + duplicates.append((current_op, key)) + winners[key] = (index, op) + else: + duplicates.append((op, key)) + + selected_indices = {index for index, _ in winners.values()} + selected = [op for index, op in enumerate(ops) if index in selected_indices] + duplicate_pairs = [ + (skipped, winners[key][1]) + for skipped, key in duplicates + if winners[key][1] is not skipped + ] + + return selected, duplicate_pairs + + +def _translate_default(param: Param) -> str: + """Translate a YAML default literal to a C++ literal.""" + raw = param.default + + if raw == "True": + return "true" + + if raw == "False": + return "false" + + if raw == "None": + return "{}" + + return raw # numeric literals (`0`, `1`, `1.0`) pass through + + +def _generate_base_header(name: str, ops: list[Op]) 
def _generate_base_header(name: str, ops: "list[Op]") -> str:
    """Render `generated/base/<name>.h` — the device-independent base class.

    Tensor params contribute shape/strides/dtype members; visible scalar
    params are stored as typed members so backends can dispatch on them.
    One constructor and one pure-virtual `operator()` per selected overload.
    """
    pascal = _snake_to_pascal(name)

    member_decls = []
    tensor_member_order = []
    seen_tensor_members = set()
    scalar_member_order = []
    scalar_member_types = {}

    for op in ops:
        for param in op.tensor_params:
            if param.name in seen_tensor_members:
                continue

            seen_tensor_members.add(param.name)
            tensor_member_order.append(param.name)
            member_decls.append(f"  Tensor::Shape {param.name}_shape_;")
            member_decls.append(f"  Tensor::Strides {param.name}_strides_;")
            member_decls.append(f"  DataType {param.name}_type_;")

        # Visible non-tensor params (scalars, strings, vectors) are also
        # stored on the base so backends can dispatch on them later — not
        # only at the moment `operator()` is invoked. Reviewers flagged
        # this on multiple PRs (e.g. `n` on
        # `special_chebyshev_polynomial_v_n_scalar`). Same-named params
        # across overloads must share a type; if they conflict, the second
        # overload's member is dropped (later constructors leave it
        # default-initialised).
        for param in op.visible_params:
            if param.is_tensor or param.name in scalar_member_types:
                continue

            scalar_member_order.append(param.name)
            scalar_member_types[param.name] = param.cpp_type
            member_decls.append(f"  {param.cpp_type} {param.name}_{{}};")

    member_decls.append("  int device_index_{0};")

    constructors = []
    calls = []

    for op in ops:
        init_pieces = []
        tensor_params = {param.name: param for param in op.tensor_params}
        scalar_params = {
            param.name: param
            for param in op.visible_params
            if not param.is_tensor
            and scalar_member_types.get(param.name) == param.cpp_type
        }

        for param_name in tensor_member_order:
            param = tensor_params.get(param_name)

            if param is None:
                continue

            init_pieces.append(f"        {param.name}_shape_{{{param.name}.shape()}}")
            init_pieces.append(
                f"        {param.name}_strides_{{{param.name}.strides()}}"
            )
            init_pieces.append(f"        {param.name}_type_{{{param.name}.dtype()}}")

        for param_name in scalar_member_order:
            param = scalar_params.get(param_name)

            if param is None:
                continue

            init_pieces.append(f"        {param.name}_{{{param.name}}}")

        # All out tensors share a device; use the first one. Keep this last
        # so initializer order follows the member declaration order.
        init_pieces.append(
            f"        device_index_{{{op.out_params[0].name}.device().index()}}"
        )

        init_list = ",\n".join(init_pieces).lstrip()
        constructors.append(
            f"  {pascal}({_format_signature(op)})\n      : {init_list} {{}}"
        )
        calls.append(f"  virtual void operator()({_format_signature(op)}) const = 0;")

    return _BASE_TEMPLATE.format(
        name_uc=name.upper(),
        pascal=pascal,
        constructors="\n\n".join(constructors),
        op_calls="\n\n".join(calls),
        member_decls="\n\n".join(member_decls),
    )


def _generate_torch_header(name: str, ops: "list[Op]") -> str:
    """Render the PyTorch backend header (`generated/torch/<name>/<name>.h`)."""
    pascal = _snake_to_pascal(name)
    op_calls = "\n\n".join(
        f"  void operator()({_format_signature(op)}) const override;" for op in ops
    )

    return _TORCH_HEADER_TEMPLATE.format(
        name_uc=name.upper(),
        name=name,
        pascal=pascal,
        op_calls=op_calls,
        slot=_PYTORCH_SLOT,
    )


def _generate_torch_method_source(name: str, op: "Op") -> str:
    """Render one `operator()` definition that forwards to ATen."""
    pascal = _snake_to_pascal(name)
    conversion_lines = []

    for param in op.tensor_params:
        # NOTE(review): `const_cast<void *>` reconstructed — the mangled
        # patch showed a bare `const_cast(...)`; confirm the cast target
        # against `ToAtenTensor`'s signature.
        data_expr = (
            f"{param.name}.data()"
            if param.is_mutable_tensor
            else f"const_cast<void *>({param.name}.data())"
        )
        conversion_lines.append(
            f"  auto at_{param.name} = ToAtenTensor(\n"
            f"      {data_expr}, {param.name}_shape_, {param.name}_strides_,\n"
            f"      {param.name}_type_, device_index_);"
        )

    def _render_arg(p):
        if p.is_hidden:
            return p.hidden_value()

        if p.is_tensor:
            return f"at_{p.name}"

        return p.name

    if op.is_inplace:
        # In-place ATen calls keep the mutable input in positional order,
        # unlike `_out` calls which place output tensors first.
        input_param = op.params[0]
        arg_order = op.params[1:]
        aten_call = (
            f"at_{input_param.name}.{op.aten_name}"
            f"({', '.join(_render_arg(p) for p in arg_order)})"
        )
    else:
        # ATen `_out` form puts all out tensors first, then non-out params
        # in YAML order. Hardcoded-nullopt params become `at::nullopt`.
        arg_order = op.out_params + [p for p in op.params if not p.is_out]
        aten_call = (
            f"at::{op.aten_name}_out({', '.join(_render_arg(p) for p in arg_order)})"
        )

    return _TORCH_METHOD_TEMPLATE.format(
        pascal=pascal,
        op_call_signature=_format_signature(op),
        tensor_conversions="\n".join(conversion_lines),
        # The generated call expression resolves the right kernel via C++
        # overload resolution from the argument types we pass.
        aten_call=aten_call,
        slot=_PYTORCH_SLOT,
    )


def _generate_torch_source(name: str, ops: "list[Op]") -> str:
    """Render `generated/torch/<name>/<name>.cc` (methods + instantiations)."""
    pascal = _snake_to_pascal(name)
    methods = "\n\n".join(_generate_torch_method_source(name, op) for op in ops)
    # Guard each explicit instantiation by the matching `WITH_<DEV>` macro
    # so a build that only enables a subset of devices does not pay the
    # ATen template-instantiation cost (and memory pressure) for the
    # devices it does not link against. Each macro is set by
    # `target_compile_definitions` in `src/CMakeLists.txt`.
    instantiations = "\n".join(
        f"#ifdef WITH_{dev.removeprefix('k').upper()}\n"
        f"template class Operator<{pascal}, Device::Type::{dev}, {_PYTORCH_SLOT}>;\n"
        f"#endif"
        for dev in _DEVICE_TYPES
    )

    return _TORCH_SOURCE_TEMPLATE.format(
        name=name,
        methods=methods,
        instantiations=instantiations,
    )


_BASE_TEMPLATE = """\
#ifndef INFINI_OPS_BASE_{name_uc}_H_
#define INFINI_OPS_BASE_{name_uc}_H_

#include "operator.h"

namespace infini::ops {{

class {pascal} : public Operator<{pascal}> {{
 public:
{constructors}

{op_calls}

 protected:
{member_decls}
}};

}}  // namespace infini::ops

#endif
"""


# NOTE(review): `template <Device::Type kDev>` reconstructed — the mangled
# patch showed a bare `template`; `kDev` is used as a `Device::Type` in the
# explicit instantiations, so this is the only consistent parameter kind.
_TORCH_HEADER_TEMPLATE = """\
#ifndef INFINI_OPS_TORCH_{name_uc}_H_
#define INFINI_OPS_TORCH_{name_uc}_H_

#include "base/{name}.h"

namespace infini::ops {{

template <Device::Type kDev>
class Operator<{pascal}, kDev, {slot}> : public {pascal} {{
 public:
  using {pascal}::{pascal};

{op_calls}
}};

}}  // namespace infini::ops

#endif
"""
+_TORCH_METHOD_TEMPLATE = """\ +template +void Operator<{pascal}, kDev, {slot}>::operator()({op_call_signature}) const {{ +{tensor_conversions} + + {aten_call}; +}} +""" + + +_TORCH_SOURCE_TEMPLATE = """\ +#include "torch/{name}/{name}.h" + +#include "torch/tensor_.h" + +namespace infini::ops {{ + +{methods} + +{instantiations} + +}} // namespace infini::ops +""" + + +def _find_clang_format() -> str | None: + """Return the path to a `clang-format` binary, or `None` if unavailable. + + Generated files live under `generated/`, which is gitignored, so offline CI + containers should not block or try network access just to format them. + """ + + found = shutil.which("clang-format") + + if found: + return found + + print( + "`clang-format` not found on PATH; generated files will be emitted " + "without formatting.", + file=sys.stderr, + ) + + return None + + +def _clang_format(text: str, path: pathlib.Path) -> str: + """Pipe `text` through `clang-format` so generated headers / sources + satisfy the same style check (`clang-format` v21) that CI runs. + `path` informs include sorting (the file's own header should come + first in a `.cc`). 
If no `clang-format` binary is available, return + the input unchanged.""" + + if _CLANG_FORMAT is None: + return text + + return subprocess.run( + [_CLANG_FORMAT, f"--assume-filename={path}"], + input=text, + capture_output=True, + text=True, + check=True, + ).stdout + + +def _emit(name: str, ops: list[Op], *, emit_base: bool) -> None: + base_path = _GENERATED_BASE_DIR / f"{name}.h" + torch_dir = _GENERATED_TORCH_DIR / name + torch_header_path = torch_dir / f"{name}.h" + torch_source_path = torch_dir / f"{name}.cc" + + if emit_base: + _GENERATED_BASE_DIR.mkdir(parents=True, exist_ok=True) + base_path.write_text(_clang_format(_generate_base_header(name, ops), base_path)) + + torch_dir.mkdir(parents=True, exist_ok=True) + + torch_header_path.write_text( + _clang_format(_generate_torch_header(name, ops), torch_header_path) + ) + torch_source_path.write_text( + _clang_format(_generate_torch_source(name, ops), torch_source_path) + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--ops", + nargs="*", + help="Override the op allowlist. If omitted, reads `scripts/torch_ops.yaml`.", + ) + parser.add_argument( + "--pytorch-version", + default=os.environ.get("INFINIOPS_PYTORCH_VERSION", _DEFAULT_PYTORCH_VERSION), + help=( + "PyTorch version label used in diagnostics. Schemas are read from " + "the locally installed `torchgen` package. Default: `%(default)s`. " + "Can also be set via the `INFINIOPS_PYTORCH_VERSION` environment " + "variable." + ), + ) + args = parser.parse_args() + + global _CLANG_FORMAT + _CLANG_FORMAT = _find_clang_format() + + op_names = args.ops or yaml.safe_load(_OPS_YAML_PATH.read_text()) + aten_entries = yaml.safe_load(_load_aten_yaml(args.pytorch_version)) + + # Wipe previous outputs so files for ops that have since been removed, + # renamed, or rejected by `cpp_type` don't linger and get picked up by + # the CMake glob. 
Both `generated/base/` and `generated/torch/` are + # written exclusively by this script. + if _GENERATED_BASE_DIR.exists(): + shutil.rmtree(_GENERATED_BASE_DIR) + + if _GENERATED_TORCH_DIR.exists(): + shutil.rmtree(_GENERATED_TORCH_DIR) + + skipped: list[tuple[str, str]] = [] + metadata: list[dict] = [] + + for name in op_names: + candidates = _find_out_entries(aten_entries, name) + + if not candidates: + skipped.append((name, f"no `.out` variant for `{name}` in YAML")) + continue + + usable: list[Op] = [] + last_reason = "" + + for entry in candidates: + try: + op = _parse_func(entry["func"]) + + for param in op.params: + _ = param.cpp_type # eagerly raise on unsupported types + except (NotImplementedError, ValueError) as exc: + last_reason = str(exc) + continue + + if not op.is_testable: + last_reason = "no testable tensor input/output pair" + continue + + usable.append(op) + + if not usable: + skipped.append((name, last_reason or "no usable overload")) + continue + + usable, duplicate_overloads = _dedupe_visible_overloads(usable) + + for skipped_op, kept_op in duplicate_overloads: + skipped.append( + ( + skipped_op.infini_name, + "duplicate visible C++ signature for " + f"`{name}`; using `{kept_op.infini_name}`", + ) + ) + + # Emit one InfiniOps wrapper per ATen op. Distinct visible overloads + # become overloaded constructors / `operator()` methods on the same + # class (`Pow` exposes both tensor and scalar exponents). Overloads + # that collapse to the same C++ signature after hidden defaults are + # skipped above. When a hand-written `src/base/.h` exists, + # skip emitting `generated/base/.h` so the hand-written one + # wins (the generated torch source's `#include "base/.h"` + # resolves through `src/` first). Signature mismatches surface as + # compile errors with a clear message — drop the op from the YAML + # to suppress. 
+ public_name = usable[0].infini_name + _emit(public_name, usable, emit_base=not _base_path(public_name).exists()) + + for op in usable: + metadata.append( + { + "name": public_name, + "aten_name": op.aten_name, + "overload_name": op.infini_name, + "params": [ + { + "name": p.name, + "type": p.aten_type, + "is_tensor": p.is_tensor, + "is_out": p.is_out, + } + for p in op.visible_params + ], + } + ) + + _GENERATED_DIR.mkdir(parents=True, exist_ok=True) + _METADATA_PATH.write_text(json.dumps({"ops": metadata}, indent=2) + "\n") + + generated_names = sorted({m["name"] for m in metadata}) + print( + f"generated {len(metadata)} overloads across {len(generated_names)} ops: " + f"{generated_names}" + ) + + for name, reason in skipped: + print(f" skipped {name!r}: {reason}", file=sys.stderr) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/torch_ops.yaml b/scripts/torch_ops.yaml new file mode 100644 index 000000000..27c4ae487 --- /dev/null +++ b/scripts/torch_ops.yaml @@ -0,0 +1,540 @@ +# Allowlist of ATen ops to expose as InfiniOps operators. +# +# Auto-discovered: every base op name with at least one parsable +# `.out` overload using the supported type vocabulary. The +# generator emits one InfiniOps wrapper per overload, so this +# file lists 500+ base names and produces 500+ wrappers. +# +# To exclude an op, comment out its line. Ops whose hand-written +# `src/base/.h` signature does not match the ATen-derived one +# (currently `add`, `linear`, `matmul`, `mul` — they pre-date this +# codegen and use a different parameter shape) must stay excluded: +# the generator skips emitting their base, but would still emit a +# torch backend declaring `operator()` with the ATen signature, and +# that override would not compile against the hand-written base. 
+ +- abs +- absolute +- acos +- acosh +- adaptive_avg_pool2d +- adaptive_avg_pool3d +- adaptive_avg_pool3d_backward +- adaptive_max_pool2d +- adaptive_max_pool2d_backward +- adaptive_max_pool3d +- adaptive_max_pool3d_backward +- addbmm +- addcdiv +- addcmul +- addmm +- addmv +- addr +- all +- amax +- amin +- aminmax +- angle +- any +- arange +- arccos +- arccosh +- arcsin +- arcsinh +- arctan +- arctan2 +- arctanh +- argmax +- argmin +- asin +- asinh +- atan +- atan2 +- atanh +- avg_pool2d +- avg_pool2d_backward +- avg_pool3d +- avg_pool3d_backward +- baddbmm +- batch_norm_elemt +- bernoulli +- binary_cross_entropy +- binary_cross_entropy_backward +- bitwise_and +- bitwise_left_shift +- bitwise_not +- bitwise_or +- bitwise_right_shift +- bitwise_xor +- bmm +- bucketize +- ceil +- cholesky +- cholesky_inverse +- cholesky_solve +- clamp +- clamp_max +- clamp_min +- clip +- col2im +- complex +- conj_physical +- copysign +- cos +- cosh +- cross +- cudnn_convolution +- cummax +- cummin +- cumprod +- cumsum +- deg2rad +- diag +- diff +- digamma +- div +- divide +- dot +- elu +- elu_backward +- empty +- eq +- erf +- erfc +- erfinv +- exp +- exp2 +- expm1 +- eye +- fft_fft +- fft_fft2 +- fft_fftfreq +- fft_fftn +- fft_hfft +- fft_hfft2 +- fft_hfftn +- fft_ifft +- fft_ifft2 +- fft_ifftn +- fft_ihfft +- fft_ihfft2 +- fft_ihfftn +- fft_irfft +- fft_irfft2 +- fft_irfftn +- fft_rfft +- fft_rfft2 +- fft_rfftfreq +- fft_rfftn +- fix +- float_power +- floor +- floor_divide +- fmax +- fmin +- fmod +- frac +- fractional_max_pool2d +- fractional_max_pool2d_backward +- fractional_max_pool3d +- fractional_max_pool3d_backward +- frexp +- frobenius_norm +- full +- gather +- gcd +- ge +- gelu +- gelu_backward +- geqrf +- ger +- glu +- glu_backward +- greater +- greater_equal +- gt +- hardshrink +- hardshrink_backward +- hardsigmoid +- hardsigmoid_backward +- hardswish +- hardtanh +- hardtanh_backward +- heaviside +- histc +- histogram +- hspmm +- huber_loss +- huber_loss_backward +- hypot 
+- i0 +- igamma +- igammac +- im2col +- index +- index_add +- index_copy +- index_reduce +- index_select +- inner +- inverse +- isin +- isneginf +- isposinf +- kron +- kthvalue +- lcm +- ldexp +- le +- leaky_relu +- leaky_relu_backward +- lerp +- less +- less_equal +- lgamma +- linalg_cholesky +- linalg_cholesky_ex +- linalg_cond +- linalg_cross +- linalg_det +- linalg_eig +- linalg_eigh +- linalg_eigvals +- linalg_eigvalsh +- linalg_householder_product +- linalg_inv +- linalg_inv_ex +- linalg_ldl_factor +- linalg_ldl_factor_ex +- linalg_ldl_solve +- linalg_lstsq +- linalg_lu +- linalg_lu_factor +- linalg_lu_factor_ex +- linalg_lu_solve +- linalg_matmul +- linalg_matrix_norm +- linalg_matrix_power +- linalg_matrix_rank +- linalg_norm +- linalg_pinv +- linalg_qr +- linalg_slogdet +- linalg_solve +- linalg_solve_ex +- linalg_solve_triangular +- linalg_svd +- linalg_svdvals +- linalg_tensorinv +- linalg_tensorsolve +- linalg_vecdot +- linalg_vector_norm +- linspace +- log +- log10 +- log1p +- log2 +- log_sigmoid +- log_sigmoid_backward +- log_sigmoid_forward +- log_softmax +- logaddexp +- logaddexp2 +- logcumsumexp +- logical_and +- logical_not +- logical_or +- logical_xor +- logit +- logit_backward +- logspace +- logsumexp +- lt +- lu_solve +- lu_unpack +- masked_select +- matrix_power +- max +- max_pool2d_with_indices +- max_pool2d_with_indices_backward +- max_pool3d_with_indices +- max_pool3d_with_indices_backward +- max_unpool2d +- max_unpool3d +- maximum +- mean +- median +- min +- minimum +- mish +- mkldnn_adaptive_avg_pool2d +- mm +- mode +- mse_loss +- mse_loss_backward +- msort +- multi_margin_loss +- multi_margin_loss_backward +- multilabel_margin_loss +- multilabel_margin_loss_backward +- multilabel_margin_loss_forward +- multinomial +- multiply +- mv +- mvlgamma +- nan_to_num +- nanmean +- nanmedian +- nanquantile +- nansum +- narrow_copy +- native_batch_norm +- ne +- neg +- negative +- nextafter +- nll_loss +- nll_loss2d +- nll_loss2d_backward +- 
nll_loss2d_forward +- nll_loss_backward +- nll_loss_forward +- nonzero +- nonzero_static +- norm +- normal +- not_equal +- nuclear_norm +- ones +- orgqr +- ormqr +- outer +- polar +- polygamma +- pow +- prod +- qr +- quantile +- rad2deg +- rand +- randint +- randn +- randperm +- range +- reciprocal +- reflection_pad1d +- reflection_pad1d_backward +- reflection_pad2d +- reflection_pad2d_backward +- reflection_pad3d +- reflection_pad3d_backward +- remainder +- renorm +- replication_pad1d +- replication_pad1d_backward +- replication_pad2d +- replication_pad2d_backward +- replication_pad3d +- replication_pad3d_backward +- round +- rrelu_with_noise +- rsqrt +- scatter +- scatter_add +- scatter_reduce +- searchsorted +- sgn +- sigmoid +- sigmoid_backward +- sign +- signbit +- silu +- silu_backward +- sin +- sinc +- sinh +- slogdet +- slow_conv3d +- slow_conv3d_forward +- slow_conv_transpose2d +- slow_conv_transpose3d +- smooth_l1_loss +- smooth_l1_loss_backward +- soft_margin_loss +- soft_margin_loss_backward +- softmax +- softplus +- softplus_backward +- softshrink +- softshrink_backward +- sort +- sparse_sampled_addmm +- special_airy_ai +- special_bessel_j0 +- special_bessel_j1 +- special_bessel_y0 +- special_bessel_y1 +- special_chebyshev_polynomial_t +- special_chebyshev_polynomial_u +- special_chebyshev_polynomial_v +- special_chebyshev_polynomial_w +- special_digamma +- special_entr +- special_erf +- special_erfc +- special_erfcx +- special_erfinv +- special_exp2 +- special_expit +- special_expm1 +- special_gammainc +- special_gammaincc +- special_gammaln +- special_hermite_polynomial_h +- special_hermite_polynomial_he +- special_i0 +- special_i0e +- special_i1 +- special_i1e +- special_laguerre_polynomial_l +- special_legendre_polynomial_p +- special_log1p +- special_log_ndtr +- special_logit +- special_logsumexp +- special_modified_bessel_i0 +- special_modified_bessel_i1 +- special_modified_bessel_k0 +- special_modified_bessel_k1 +- special_multigammaln +- 
special_ndtr +- special_ndtri +- special_polygamma +- special_psi +- special_round +- special_scaled_modified_bessel_k0 +- special_scaled_modified_bessel_k1 +- special_shifted_chebyshev_polynomial_t +- special_shifted_chebyshev_polynomial_u +- special_shifted_chebyshev_polynomial_v +- special_shifted_chebyshev_polynomial_w +- special_sinc +- special_spherical_bessel_j0 +- special_xlog1py +- special_xlogy +- special_zeta +- split_copy +- split_with_sizes_copy +- sqrt +- square +- sspaddmm +- std +- sub +- subtract +- sum +- svd +- take +- take_along_dim +- tan +- tanh +- tanh_backward +- tensordot +- thnn_conv2d +- threshold +- threshold_backward +- topk +- triangular_solve +- tril +- triu +- true_divide +- trunc +- unbind_copy +- upsample_bicubic2d +- upsample_bicubic2d_backward +- upsample_bilinear2d +- upsample_bilinear2d_backward +- upsample_linear1d +- upsample_linear1d_backward +- upsample_nearest1d +- upsample_nearest1d_backward +- upsample_nearest2d +- upsample_nearest2d_backward +- upsample_nearest3d +- upsample_nearest3d_backward +- upsample_trilinear3d +- upsample_trilinear3d_backward +- var +- vdot +- where +- xlogy +- zeros +- _add_relu +- _addmm_activation +- _batch_norm_with_update +- _conv_depthwise2d +- _convert_indices_from_coo_to_csr +- _convert_indices_from_csr_to_coo +- _fft_c2c +- _fft_c2r +- _fft_r2c +- _int_mm +- _linalg_det +- _linalg_eigh +- _linalg_slogdet +- _linalg_solve_ex +- _linalg_svd +- _log_softmax +- _logcumsumexp +- _scaled_mm +- _slow_conv2d_backward +- _slow_conv2d_forward +- _softmax +- _upsample_bicubic2d_aa +- _upsample_bicubic2d_aa_backward +- _upsample_bilinear2d_aa +- _upsample_bilinear2d_aa_backward +- _upsample_nearest_exact1d +- _upsample_nearest_exact1d_backward +- _upsample_nearest_exact2d +- _upsample_nearest_exact2d_backward +- _upsample_nearest_exact3d +- _upsample_nearest_exact3d_backward +- add_ +- argsort +- bernoulli_ +- bitwise_and_ +- bitwise_left_shift_ +- bitwise_or_ +- bitwise_right_shift_ +- bitwise_xor_ 
+- clamp_max_ +- clamp_min_ +- copysign_ +- div_ +- divide_ +- eq_ +- fill_ +- float_power_ +- floor_divide_ +- fmod_ +- ge_ +- greater_ +- greater_equal_ +- gt_ +- le_ +- lerp_ +- less_ +- less_equal_ +- lt_ +- masked_fill_ +- mul_ +- multiply_ +- ne_ +- not_equal_ +- pow_ +- remainder_ +- set_ +- sub_ +- subtract_ +- true_divide_ +- xlogy_ From d7319fbfdb1bacd2aa10c55190f85a39b2e44385 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 16 May 2026 09:41:39 +0800 Subject: [PATCH 3/4] build(torch): integrate generated torch backend --- CMakeLists.txt | 62 +++++- scripts/generate_wrappers.py | 408 +++++++++++++++++++++++++++-------- src/CMakeLists.txt | 225 ++++++++++++++++++- 3 files changed, 590 insertions(+), 105 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 91c2b0154..9973438cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,31 +92,71 @@ endif() if(AUTO_DETECT_BACKENDS) message(STATUS "Auto-detecting available backends...") + # The Python that scikit-build's build-isolated environment hands + # us does not have `torch` (only `[build-system].requires` is + # installed). Fall back to a list of common system interpreters so + # the auto-detection finds `torch` when it is in the install env + # but not the build env. The first interpreter that successfully + # imports `torch` wins and is reused by the `WITH_TORCH` block + # below for include / library lookups. 
find_package(Python COMPONENTS Interpreter QUIET) - if(Python_FOUND) + set(_torch_python_candidates "${Python_EXECUTABLE}") + foreach(_candidate + python3 + python + /usr/bin/python3 + /usr/local/bin/python3 + /opt/conda/bin/python + /opt/conda/bin/python3) + find_program(_resolved_${_candidate} ${_candidate}) + if(_resolved_${_candidate} AND + NOT _resolved_${_candidate} STREQUAL "${Python_EXECUTABLE}") + list(APPEND _torch_python_candidates "${_resolved_${_candidate}}") + endif() + endforeach() + + foreach(_py ${_torch_python_candidates}) + if(NOT _py) + continue() + endif() + execute_process( - COMMAND ${Python_EXECUTABLE} -c "import torch" + COMMAND "${_py}" -c "import torch" RESULT_VARIABLE _torch_import_result OUTPUT_QUIET ERROR_QUIET ) if(_torch_import_result EQUAL 0) - set(WITH_TORCH ON) - message(STATUS "Auto-detected PyTorch.") + set(_TORCH_PYTHON "${_py}") + break() endif() + endforeach() + + if(_TORCH_PYTHON) + set(WITH_TORCH ON) + message(STATUS "Auto-detected PyTorch (via ${_TORCH_PYTHON}).") endif() endif() if(WITH_TORCH) find_package(Python COMPONENTS Interpreter REQUIRED) + # Prefer the interpreter that the auto-detect block already + # confirmed has `torch` (this is the system Python on hosts that + # use scikit-build's build-isolation, where the build interpreter + # does not have `torch`). Fall back to `Python_EXECUTABLE` for + # explicit `-DWITH_TORCH=ON` invocations. + if(NOT _TORCH_PYTHON) + set(_TORCH_PYTHON "${Python_EXECUTABLE}") + endif() + # Query `torch` paths directly instead of using `find_package(Torch)`, # which pulls in Caffe2's CMake config and may fail on platforms with # non-standard CUDA toolchains. 
execute_process( - COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" + COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" OUTPUT_VARIABLE TORCH_INCLUDE_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE RESULT_VARIABLE _torch_result @@ -127,7 +167,7 @@ if(WITH_TORCH) endif() execute_process( - COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))" + COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))" OUTPUT_VARIABLE _torch_lib_dirs OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -144,7 +184,7 @@ if(WITH_TORCH) # the bundled `NEEDED` entries (otherwise: `undefined reference to # _gfortran_etime@GFORTRAN_8` etc.). execute_process( - COMMAND ${Python_EXECUTABLE} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')" + COMMAND ${_TORCH_PYTHON} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')" OUTPUT_VARIABLE TORCH_BUNDLED_LIBS_DIR OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -163,7 +203,7 @@ if(WITH_TORCH) # A mismatch causes linker errors (e.g. undefined reference to # `c10::Device::Device(std::string const&)`). 
execute_process( - COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" + COMMAND ${_TORCH_PYTHON} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE RESULT_VARIABLE _torch_abi_result @@ -314,10 +354,12 @@ if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE A add_compile_definitions(WITH_CPU=1) endif() -if(WITH_METAX OR WITH_MOORE) +if(WITH_TORCH OR WITH_METAX OR WITH_MOORE) set(PYBIND11_ENABLE_EXTRAS OFF) endif() add_subdirectory(src) -add_subdirectory(examples) +if(NOT GENERATE_PYTHON_BINDINGS) + add_subdirectory(examples) +endif() diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 7b9a34286..f2ff37065 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,4 +1,6 @@ import argparse +import concurrent.futures +import functools import json import os import pathlib @@ -16,6 +18,11 @@ _GENERATION_DIR = pathlib.Path("generated") +# Base headers emitted by `generate_torch_ops.py` live alongside the +# hand-written ones in `src/base/`, but in a parallel tree under +# `generated/base/` so they are not committed. 
+_GENERATED_BASE_DIR = _GENERATION_DIR / "base" + _BINDINGS_DIR = _GENERATION_DIR / "bindings" _GENERATED_SRC_DIR = _GENERATION_DIR / "src" @@ -25,37 +32,61 @@ _INDENTATION = " " -class _OperatorExtractor: - def __call__(self, op_name): - def _get_system_include_flags(): - def _get_compilers(): - compilers = [] +@functools.lru_cache(maxsize=1) +def _get_system_include_flags(): + """Probe the system C++ compiler for default include paths so libclang + can resolve standard headers when parsing an op's base header.""" + compilers = [] + + for compiler in ("clang++", "g++"): + if shutil.which(compiler) is not None: + compilers.append(compiler) + + system_include_flags = [] + + for compiler in compilers: + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue + + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) - for compiler in ("clang++", "g++"): - if shutil.which(compiler) is not None: - compilers.append(compiler) + return tuple(system_include_flags) - return compilers - system_include_flags = [] +def _find_base_header(op_name): + """Resolve the base header for `op_name`, preferring the hand-written + `src/base/.h` over the auto-generated `generated/base/.h`. 
+ Mirrors the include-path resolution order used at compile time.""" + src_path = _BASE_DIR / f"{op_name}.h" - for compiler in _get_compilers(): - for line in subprocess.getoutput( - f"{compiler} -E -x c++ -v /dev/null" - ).splitlines(): - if not line.startswith(" "): - continue + if src_path.exists(): + return src_path - system_include_flags.append("-isystem") - system_include_flags.append(line.strip()) + generated_path = _GENERATED_BASE_DIR / f"{op_name}.h" - return system_include_flags + if generated_path.exists(): + return generated_path - system_include_flags = _get_system_include_flags() + raise FileNotFoundError(f"no base header for op {op_name!r}") + +class _OperatorExtractor: + def __call__(self, op_name): index = clang.cindex.Index.create() - args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name}.h", args=args) + args = ( + "-std=c++17", + "-x", + "c++", + "-I", + "src", + "-I", + str(_GENERATION_DIR), + ) + _get_system_include_flags() + translation_unit = index.parse(str(_find_base_header(op_name)), args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -99,7 +130,7 @@ def _find_optional_tensor_params(op_name): headers are not fully available, so we fall back to a regex scan of the source text. """ - source = (_BASE_DIR / f"{op_name}.h").read_text() + source = _find_base_header(op_name).read_text() return set(re.findall(r"std::optional\s+(\w+)", source)) @@ -108,14 +139,31 @@ def _find_vector_tensor_params(op_name): """Return a set of parameter names declared as `std::vector` in the base header. """ - source = (_BASE_DIR / f"{op_name}.h").read_text() + source = _find_base_header(op_name).read_text() return set(re.findall(r"std::vector\s+(\w+)", source)) +def _find_vector_int64_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. 
+ + libclang on systems where the STL headers are not fully indexable + silently falls back to reporting the type as `int` for these params, + which then leaks into the generated bindings as `const int padding` + instead of `const std::vector padding` and breaks the call + to the base operator. Regex-scan the source so the binding's + parameter type comes from the actual declaration. + """ + source = _find_base_header(op_name).read_text() + + return set(re.findall(r"std::vector\s+(\w+)", source)) + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) + vector_int64_params = _find_vector_int64_params(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: @@ -132,6 +180,9 @@ def _is_vector_tensor(arg): return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _is_vector_int64(arg): + return arg.spelling in vector_int64_params + def _generate_params(node): parts = [] @@ -143,6 +194,8 @@ def _generate_params(node): parts.append(f"std::optional {arg.spelling}") elif _is_vector_tensor(arg): parts.append(f"std::vector {arg.spelling}") + elif _is_vector_int64(arg): + parts.append(f"const std::vector {arg.spelling}") else: param = arg.type.spelling.replace("const Tensor", "py::object").replace( "Tensor", "py::object" @@ -174,6 +227,7 @@ def _generate_arguments(node): def _generate_init(constructor): constructor_params = _generate_params(constructor) + return f""" .def(py::init([]({constructor_params}) {{ Config config; return std::unique_ptr{{static_cast(generated_dispatch::Make{pascal_case_op_name}(config, {_generate_arguments(constructor)}).release())}}; @@ -196,6 +250,7 @@ def _generate_py_args(node): def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) + if not method: params = ( f"{call_params}, std::uintptr_t stream, std::size_t 
implementation_index" @@ -217,16 +272,52 @@ def _generate_call(op_name, call, method=True): f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' ) - return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return generated_dispatch::Invoke{pascal_case_op_name}(self, {call_args}); + # The first lambda parameter is conventionally named `self`, but + # ATen schemas often have a parameter literally called `self` + # (e.g. `pow.Tensor_Scalar_out(Scalar self, Tensor exponent)`), + # so rename to `op` to avoid the collision in the generated code. + + return f""" .def("__call__", [](const Self& op, {call_params}) {{ + return generated_dispatch::Invoke{pascal_case_op_name}(op, {call_args}); }})""" - inits = "\n".join( - _generate_init(constructor) for constructor in operator.constructors - ) - calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) + def _overload_order_key(node): + """Sort key that places more-specific overloads first. + + Tensor parameters are exposed to pybind as `py::object`, which + accepts any Python value and only fails inside + `TensorFromPybind11Handle`. When a class has both Tensor and + scalar overloads, pybind's overload-resolver tries them in + registration order and stops at the first that does not raise, + so the scalar overload must be registered first; otherwise the + permissive Tensor signature swallows scalar calls and aborts at + runtime. 
+ """ + object_like = 0 + total = 0 + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + total += 1 + + if ( + _is_optional_tensor(arg) + or _is_vector_tensor(arg) + or "Tensor" in arg.type.spelling + ): + object_like += 1 + + return (object_like, -total) + + constructors = sorted(operator.constructors, key=_overload_order_key) + operator_calls = sorted(operator.calls, key=_overload_order_key) + + inits = "\n".join(_generate_init(constructor) for constructor in constructors) + calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls) callers = "\n".join( - _generate_call(operator.name, call, method=False) for call in operator.calls + _generate_call(operator.name, call, method=False) for call in operator_calls ) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -252,7 +343,11 @@ def _generate_call(op_name, call, method=True): {inits} {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ - return generated_dispatch::ActiveImplementationIndicesFor{pascal_case_op_name}(DeviceTypeFromString(device)); + auto dev_type = TryDeviceTypeFromString(device); + if (!dev_type.has_value()) {{ + return std::vector{{}}; + }} + return generated_dispatch::ActiveImplementationIndicesFor{pascal_case_op_name}(*dev_type); }}) .def_static("clear_cache", &generated_dispatch::ClearCacheFor{pascal_case_op_name}); @@ -268,7 +363,7 @@ def _generate_call(op_name, call, method=True): def _generate_legacy_c(operator, paths): def _generate_source(operator): impl_includes = "\n".join( - f'#include "{str(path).removeprefix("src/")}"' for path in paths + f'#include "{_to_include_path(path)}"' for path in paths ) return f"""#include "../../handle.h" @@ -411,6 +506,7 @@ def _generate_params(node, call=False): def _handle_tensor(spelling): if call: return spelling.replace("Tensor", "void *") + return spelling.replace("Tensor", "infiniopTensorDescriptor_t") def _handle_std_optional(spelling): @@ -524,9 
+620,9 @@ def _append_optional_params(prefix, params): return declarations, definitions -def _generate_generated_dispatch_header(operators, devices, declarations): +def _generate_generated_dispatch_header(op_names, devices, declarations): header_base_includes = "\n".join( - f'#include "base/{operator.name}.h"' for operator in operators + f'#include "base/{op_name}.h"' for op_name in op_names ) header_device_includes = "\n".join( f'#include "{path}"' for path in _device_marker_headers(devices) @@ -576,18 +672,6 @@ def _generate_generated_dispatch_source(impl_paths, definitions): """ -def _dispatch_gen_batch_size(): - raw = os.environ.get("INFINIOPS_DISPATCH_BATCH_SIZE") - - if raw: - try: - return max(1, int(raw)) - except ValueError: - return 8 - - return 8 - - def _device_marker_headers(devices): paths = { "cpu": "native/cpu/device_.h", @@ -611,6 +695,46 @@ def _snake_to_pascal(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) +def _to_include_path(path): + text = str(path) + + for prefix in ("src/", "generated/"): + if text.startswith(prefix): + return text[len(prefix) :] + + return text + + +def _matches_scan_dir(impl_path, scan_dirs): + return any(part in scan_dirs for part in impl_path.parts) + + +_OPERATOR_DECL_RE = re.compile(r"\bclass\s+Operator<\s*([A-Za-z_][A-Za-z0-9_]*)\b") + + +def _index_impl_headers(impl_roots, scan_dirs): + """Index implementation headers by base operator class name. + + The previous implementation scanned every implementation header once per + operator. With the generated PyTorch backend enabled this becomes hundreds + of ops times hundreds of headers during CMake configure. Read each header + once instead and keep the same insertion order as the old nested loops. 
+ """ + by_operator = {} + + for impl_root in impl_roots: + for impl_path in impl_root.rglob("*.h"): + if not _matches_scan_dir(impl_path, scan_dirs): + continue + + text = impl_path.read_text() + + for match in _OPERATOR_DECL_RE.finditer(text): + by_operator.setdefault(match.group(1), []).append(impl_path) + + return by_operator + + def _get_all_ops(devices, with_torch=False): scan_dirs = set(devices) @@ -619,24 +743,100 @@ def _get_all_ops(devices, with_torch=False): ops = {} - for file_path in _BASE_DIR.iterdir(): - if not file_path.is_file(): - continue + base_dirs = [_BASE_DIR] + + # Only pull in the auto-generated torch op bases when the build is + # actually compiling them (`--with-torch`). Otherwise a stale + # `generated/` left over from a previous configure (or rsynced into + # a CI container) would cause `ops.cc` to include base headers for + # ops that have no compiled implementation, breaking the build. + if with_torch and _GENERATED_BASE_DIR.exists(): + base_dirs.append(_GENERATED_BASE_DIR) + + impl_roots = [_SRC_DIR] - op_name = file_path.stem + if with_torch and (_GENERATION_DIR / "torch").exists(): + impl_roots.append(_GENERATION_DIR) - ops[op_name] = [] + impl_headers_by_operator = _index_impl_headers(impl_roots, scan_dirs) - for file_path in _SRC_DIR.rglob("*.h"): - if file_path.parent.parent.parent.name not in scan_dirs: + for base_dir in base_dirs: + for file_path in base_dir.iterdir(): + if not file_path.is_file(): continue - if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): - ops[op_name].append(file_path) + op_name = file_path.stem + + # Hand-written `src/base/` is scanned first; the generated + # tree never overrides an already-known op. 
+ if op_name in ops: + continue + + ops[op_name] = [] + ops[op_name].extend( + impl_headers_by_operator.get(_snake_to_pascal(op_name), ()) + ) return ops +def _generate_op_artifacts(item): + op_name, impl_paths = item + extractor = _OperatorExtractor() + operator = extractor(op_name) + header_name = f"{op_name}.h" + legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) + dispatch_declarations, dispatch_definitions = _generate_generated_dispatch_entries( + operator + ) + + return { + "op_name": op_name, + "header_name": header_name, + "bind_func_name": f"Bind{_snake_to_pascal(op_name)}", + "pybind11": _generate_pybind11(operator), + "binding_source": _generate_binding_source(op_name), + "legacy_c_source": legacy_c_source, + "legacy_c_header": legacy_c_header, + "dispatch_declarations": dispatch_declarations, + "dispatch_definitions": dispatch_definitions, + "impl_paths": impl_paths, + } + + +def _wrapper_gen_jobs(with_torch): + raw = os.environ.get("INFINIOPS_WRAPPER_GEN_JOBS") + + if raw: + try: + return max(1, int(raw)) + except ValueError: + return 1 + + if not with_torch: + return 1 + + return min(os.cpu_count() or 1, 8) + + +def _use_monolithic_bindings(): + value = os.environ.get("INFINIOPS_MONOLITHIC_BINDINGS", "") + + return value.upper() in {"1", "ON", "TRUE"} + + +def _dispatch_gen_batch_size(): + raw = os.environ.get("INFINIOPS_DISPATCH_BATCH_SIZE") + + if raw: + try: + return max(1, int(raw)) + except ValueError: + return 8 + + return 8 + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="An automatic wrapper generator.") @@ -656,6 +856,9 @@ def _get_all_ops(devices, with_torch=False): args = parser.parse_args() + # Wipe previous outputs so files for ops that have since been removed + # from the active set (e.g. when toggling `--with-torch`) do not linger + # and get globbed by a later build. 
for directory in (_BINDINGS_DIR, _GENERATED_SRC_DIR, _INCLUDE_DIR): if directory.exists(): shutil.rmtree(directory) @@ -670,61 +873,96 @@ def _get_all_ops(devices, with_torch=False): ops = _get_all_ops(args.devices, with_torch=args.with_torch) bind_func_names = [] - operators = [] - dispatch_declarations = [] - dispatch_batches = [] - for op_name, impl_paths in ops.items(): - extractor = _OperatorExtractor() - operator = extractor(op_name) - operators.append(operator) - declarations, definitions = _generate_generated_dispatch_entries(operator) + jobs = _wrapper_gen_jobs(args.with_torch) + if jobs == 1: + artifacts = [_generate_op_artifacts(item) for item in ops.items()] + else: + with concurrent.futures.ProcessPoolExecutor(max_workers=jobs) as executor: + artifacts = list(executor.map(_generate_op_artifacts, ops.items())) + + op_names = [artifact["op_name"] for artifact in artifacts] + dispatch_declarations = [ + declaration + for artifact in artifacts + for declaration in artifact["dispatch_declarations"] + ] + use_monolithic_bindings = _use_monolithic_bindings() + op_includes = [] + + for artifact in artifacts: + op_name = artifact["op_name"] source_path = _GENERATED_SRC_DIR / op_name - header_name = f"{op_name}.h" - bind_func_name = f"Bind{_snake_to_pascal(op_name)}" + header_name = artifact["header_name"] + bind_func_name = artifact["bind_func_name"] + + (_BINDINGS_DIR / header_name).write_text(artifact["pybind11"]) - (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) - (_BINDINGS_DIR / f"{op_name}.cc").write_text(_generate_binding_source(op_name)) + if use_monolithic_bindings: + op_includes.append(f'#include "{header_name}"') + else: + (_BINDINGS_DIR / f"{op_name}.cc").write_text(artifact["binding_source"]) - legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) - (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source) - (_INCLUDE_DIR / 
header_name).write_text(legacy_c_header) + (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text( + artifact["legacy_c_source"] + ) + (_INCLUDE_DIR / header_name).write_text(artifact["legacy_c_header"]) bind_func_names.append(bind_func_name) - dispatch_declarations.extend(declarations) - dispatch_batches.append((impl_paths, definitions)) dispatch_header = _generate_generated_dispatch_header( - operators, args.devices, dispatch_declarations + op_names, args.devices, dispatch_declarations ) (_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header) dispatch_batch_size = _dispatch_gen_batch_size() for dispatch_batch_index, start in enumerate( - range(0, len(dispatch_batches), dispatch_batch_size) + range(0, len(artifacts), dispatch_batch_size) ): - batch = dispatch_batches[start : start + dispatch_batch_size] + batch = artifacts[start : start + dispatch_batch_size] impl_paths = list( - dict.fromkeys(impl_path for paths, _ in batch for impl_path in paths) + dict.fromkeys( + impl_path for artifact in batch for impl_path in artifact["impl_paths"] + ) ) - definitions = [definition for _, defs in batch for definition in defs] + definitions = [ + definition + for artifact in batch + for definition in artifact["dispatch_definitions"] + ] dispatch_source = _generate_generated_dispatch_source(impl_paths, definitions) (_BINDINGS_DIR / f"generated_dispatch_{dispatch_batch_index}.cc").write_text( dispatch_source ) - bind_func_declarations = "\n".join( - f"void {bind_func_name}(pybind11::module& m);" - for bind_func_name in bind_func_names - ) bind_func_calls = "\n".join( f"{bind_func_name}(m);" for bind_func_name in bind_func_names ) - (_BINDINGS_DIR / "ops.cc").write_text(f"""#include + if use_monolithic_bindings: + op_includes = "\n".join(op_includes) + ops_source = f"""#include + +// Generated with `INFINIOPS_MONOLITHIC_BINDINGS=1`. 
+{op_includes} + +namespace infini::ops {{ + +PYBIND11_MODULE(ops, m) {{ +{textwrap.indent(bind_func_calls, _INDENTATION)} +}} + +}} // namespace infini::ops +""" + else: + bind_func_declarations = "\n".join( + f"void {bind_func_name}(pybind11::module& m);" + for bind_func_name in bind_func_names + ) + ops_source = f"""#include namespace infini::ops {{ @@ -735,4 +973,6 @@ def _get_all_ops(devices, with_torch=False): }} }} // namespace infini::ops -""") +""" + + (_BINDINGS_DIR / "ops.cc").write_text(ops_source) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 35d63ff85..762b9d48f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,7 +90,7 @@ if(WITH_METAX) set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX) target_compile_definitions(infiniops PRIVATE WITH_METAX=1) - target_compile_options(infiniops PUBLIC "-x" "maca") + target_compile_options(infiniops PRIVATE "-x" "maca") target_sources(infiniops PRIVATE ${METAX_SOURCES}) target_include_directories(infiniops PUBLIC "${MACA_PATH}/include") @@ -117,7 +117,7 @@ if(WITH_MOORE) set_source_files_properties(${MOORE_SOURCES} PROPERTIES LANGUAGE CXX) target_compile_definitions(infiniops PRIVATE WITH_MOORE=1) - target_compile_options(infiniops PUBLIC "-x" "musa") + target_compile_options(infiniops PRIVATE "-x" "musa") target_sources(infiniops PRIVATE ${MOORE_SOURCES}) target_include_directories(infiniops PUBLIC "${MUSA_ROOT}/include") @@ -252,11 +252,112 @@ if(WITH_ASCEND) endif() if(WITH_TORCH) - file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp") + # Auto-generate ATen-backed operator wrappers from `scripts/torch_ops.yaml`. + # The script writes into `${PROJECT_SOURCE_DIR}/generated/` (gitignored), + # which we then glob below alongside any hand-written torch sources. + find_package(Python COMPONENTS Interpreter REQUIRED) + + # Pin codegen to the locally installed torch version so vendor + # forks (Cambricon's `torch_mlu` 2.1.0, etc.) 
get a schema whose + # `at::_out` overloads match the headers they ship. Without + # this, the codegen targets v2.4.0 and the build fails on older + # forks with no-known-conversion errors (e.g. `at::all_out`'s + # `int64_t dim` vs `OptionalIntArrayRef dim`). + execute_process( + COMMAND ${_TORCH_PYTHON} -c + "import torch; print('v' + torch.__version__.split('+')[0])" + OUTPUT_VARIABLE _torch_version_tag + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _torch_version_result + ) + if(NOT _torch_version_result EQUAL 0 OR NOT _torch_version_tag) + set(_torch_version_tag "v2.4.0") + endif() + message(STATUS "Codegen schema: PyTorch ${_torch_version_tag}") + + set(INFINIOPS_TORCH_OPS "" CACHE STRING + "Semicolon-separated PyTorch op allowlist for generated torch wrappers") + set(_torch_codegen_args + ${PROJECT_SOURCE_DIR}/scripts/generate_torch_ops.py + --pytorch-version ${_torch_version_tag}) + if(INFINIOPS_TORCH_OPS) + string(REPLACE "," ";" _torch_op_allowlist "${INFINIOPS_TORCH_OPS}") + list(APPEND _torch_codegen_args --ops ${_torch_op_allowlist}) + message(STATUS "Codegen torch op allowlist: ${_torch_op_allowlist}") + endif() + + execute_process( + COMMAND ${Python_EXECUTABLE} ${_torch_codegen_args} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE _torch_ops_result + ) + if(NOT _torch_ops_result EQUAL 0) + message(FATAL_ERROR "Generating torch op wrappers - failed") + endif() + message(STATUS "Generating torch op wrappers - done") + + file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS + "torch/*.cc" "torch/*.cpp" + "${PROJECT_SOURCE_DIR}/generated/torch/*.cc" + "${PROJECT_SOURCE_DIR}/generated/torch/*.cpp" + ) + + set(INFINIOPS_TORCH_UNITY_BATCH_SIZE "8" CACHE STRING + "Number of torch sources to include in each generated unity translation unit; set to 1 to disable") + set(TORCH_COMPILE_SOURCES ${TORCH_SOURCES}) + if(INFINIOPS_TORCH_UNITY_BATCH_SIZE GREATER 1) + set(_torch_unity_dir "${CMAKE_CURRENT_BINARY_DIR}/torch_unity") + file(REMOVE_RECURSE 
"${_torch_unity_dir}") + file(MAKE_DIRECTORY "${_torch_unity_dir}") + + set(TORCH_COMPILE_SOURCES) + set(_torch_unity_index 0) + set(_torch_unity_count 0) + foreach(_src IN LISTS TORCH_SOURCES) + if(_torch_unity_count EQUAL 0) + set(_torch_unity_src + "${_torch_unity_dir}/torch_unity_${_torch_unity_index}.cc") + file(WRITE "${_torch_unity_src}" + "// Generated by CMake to batch ATen-heavy torch wrappers.\n") + list(APPEND TORCH_COMPILE_SOURCES "${_torch_unity_src}") + endif() + + file(APPEND "${_torch_unity_src}" "#include \"${_src}\"\n") + + math(EXPR _torch_unity_count "${_torch_unity_count} + 1") + if(_torch_unity_count GREATER_EQUAL INFINIOPS_TORCH_UNITY_BATCH_SIZE) + math(EXPR _torch_unity_index "${_torch_unity_index} + 1") + set(_torch_unity_count 0) + endif() + endforeach() + + list(LENGTH TORCH_SOURCES _torch_source_count) + list(LENGTH TORCH_COMPILE_SOURCES _torch_unity_source_count) + message(STATUS + "Torch unity build: ${_torch_source_count} sources batched into " + "${_torch_unity_source_count} translation units") + endif() target_compile_definitions(infiniops PUBLIC WITH_TORCH=1) target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES}) - target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS}) + target_include_directories(infiniops PUBLIC + ${TORCH_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/generated + ) + + # Each generated `.cc` instantiates `at::_out(...)`, which + # pulls in roughly 0.5-1 GB of ATen template metaprogramming. At + # ninja's default parallelism (one job per CPU), a build with 451 + # ops can blow past 30 GB of RSS and the OOM killer drops + # `cc1plus`. Cap the heavyweight torch sources via a Ninja job pool; + # the rest of the build keeps full parallelism. CI can raise this when + # the container has enough memory. 
+ set(INFINIOPS_TORCH_COMPILE_JOBS "2" CACHE STRING + "Maximum concurrent generated torch source compilations") + if(CMAKE_GENERATOR MATCHES "Ninja") + set_property(GLOBAL APPEND PROPERTY JOB_POOLS + torch_compile=${INFINIOPS_TORCH_COMPILE_JOBS}) + endif() if(WITH_METAX OR WITH_MOORE) # Vendor compilers (`mxcc`/`mcc`) cannot compile vendor-forked `torch` @@ -275,7 +376,13 @@ if(WITH_TORCH) # Vendor-specific defines required by forked `torch` headers. set(_torch_extra_flags "") if(WITH_METAX) - list(APPEND _torch_extra_flags "-DUSE_MACA=1") + list(APPEND _torch_extra_flags "-DUSE_MACA=1" "-DWITH_METAX=1") + endif() + if(WITH_MOORE) + list(APPEND _torch_extra_flags "-DWITH_MOORE=1") + endif() + if(WITH_CPU) + list(APPEND _torch_extra_flags "-DWITH_CPU=1") endif() if(DEFINED TORCH_CXX11_ABI) list(APPEND _torch_extra_flags "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}") @@ -285,7 +392,7 @@ if(WITH_TORCH) file(MAKE_DIRECTORY "${TORCH_OBJECT_DIR}") set(TORCH_OBJECT_FILES) - foreach(_src ${TORCH_SOURCES}) + foreach(_src ${TORCH_COMPILE_SOURCES}) file(RELATIVE_PATH _rel ${CMAKE_CURRENT_SOURCE_DIR} ${_src}) string(REPLACE "/" "_" _obj_name "${_rel}") string(REPLACE ".cc" ".o" _obj_name "${_obj_name}") @@ -295,13 +402,15 @@ if(WITH_TORCH) add_custom_command( OUTPUT "${_obj}" COMMAND ${SYSTEM_CXX} - -std=c++17 -fPIC -O2 + -std=c++17 -fPIC -O0 "-I${CMAKE_CURRENT_SOURCE_DIR}" + "-I${PROJECT_SOURCE_DIR}/generated" ${_torch_include_flags} ${_torch_extra_flags} -c "${_src}" -o "${_obj}" DEPENDS "${_src}" COMMENT "Compiling ${_rel} with system C++ compiler" + JOB_POOL torch_compile ) list(APPEND TORCH_OBJECT_FILES "${_obj}") endforeach() @@ -310,7 +419,30 @@ if(WITH_TORCH) PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES}) else() - target_sources(infiniops PRIVATE ${TORCH_SOURCES}) + # Build the heavy torch sources as their own object library so + # the Ninja `torch_compile` job pool throttles only those + # compilations and the 
rest of `infiniops` keeps full + # parallelism. Inherit infiniops's compile-time settings via + # generator expressions (linking would create a cyclic + # dependency since infiniops then absorbs the object files). + add_library(infiniops_torch_objs OBJECT ${TORCH_COMPILE_SOURCES}) + set_target_properties(infiniops_torch_objs + PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_include_directories(infiniops_torch_objs PRIVATE + $ + ${TORCH_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/generated) + target_compile_definitions(infiniops_torch_objs PRIVATE + $) + target_compile_options(infiniops_torch_objs PRIVATE + $ + -O0) + if(CMAKE_GENERATOR MATCHES "Ninja") + set_target_properties(infiniops_torch_objs + PROPERTIES JOB_POOL_COMPILE torch_compile) + endif() + target_sources(infiniops PRIVATE + $) endif() endif() @@ -343,9 +475,58 @@ if(GENERATE_PYTHON_BINDINGS) file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/generated/bindings/*.cc") + set(PYBIND11_DISPATCH_SOURCES) + set(PYBIND11_UNITY_SOURCES) + foreach(_src IN LISTS PYBIND11_SOURCES) + if(_src MATCHES "/generated_dispatch[^/]*\\.cc$") + list(APPEND PYBIND11_DISPATCH_SOURCES "${_src}") + else() + list(APPEND PYBIND11_UNITY_SOURCES "${_src}") + endif() + endforeach() + + set(PYBIND11_COMPILE_SOURCES ${PYBIND11_UNITY_SOURCES}) + if(WITH_TORCH) + set(INFINIOPS_BINDING_UNITY_BATCH_SIZE "8" CACHE STRING + "Number of generated pybind11 sources to include in each unity translation unit; set to 1 to disable") + if(INFINIOPS_BINDING_UNITY_BATCH_SIZE GREATER 1) + set(_binding_unity_dir "${CMAKE_CURRENT_BINARY_DIR}/binding_unity") + file(REMOVE_RECURSE "${_binding_unity_dir}") + file(MAKE_DIRECTORY "${_binding_unity_dir}") + + set(PYBIND11_COMPILE_SOURCES) + set(_binding_unity_index 0) + set(_binding_unity_count 0) + foreach(_src IN LISTS PYBIND11_UNITY_SOURCES) + if(_binding_unity_count EQUAL 0) + set(_binding_unity_src + "${_binding_unity_dir}/binding_unity_${_binding_unity_index}.cc") + file(WRITE 
"${_binding_unity_src}" + "// Generated by CMake to batch pybind11 wrapper sources.\n") + list(APPEND PYBIND11_COMPILE_SOURCES "${_binding_unity_src}") + endif() + + file(APPEND "${_binding_unity_src}" "#include \"${_src}\"\n") + + math(EXPR _binding_unity_count "${_binding_unity_count} + 1") + if(_binding_unity_count GREATER_EQUAL INFINIOPS_BINDING_UNITY_BATCH_SIZE) + math(EXPR _binding_unity_index "${_binding_unity_index} + 1") + set(_binding_unity_count 0) + endif() + endforeach() + + list(LENGTH PYBIND11_UNITY_SOURCES _binding_source_count) + list(LENGTH PYBIND11_COMPILE_SOURCES _binding_unity_source_count) + message(STATUS + "Binding unity build: ${_binding_source_count} sources batched into " + "${_binding_unity_source_count} translation units") + endif() + endif() + list(APPEND PYBIND11_COMPILE_SOURCES ${PYBIND11_DISPATCH_SOURCES}) + # TODO: There might be a better solution. if(WITH_NVIDIA OR WITH_ILUVATAR) - set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) + set_source_files_properties(${PYBIND11_COMPILE_SOURCES} PROPERTIES LANGUAGE CUDA) endif() find_package(Python COMPONENTS Interpreter Development) @@ -366,9 +547,24 @@ if(GENERATE_PYTHON_BINDINGS) find_package(pybind11 CONFIG) if(PYBIND11_ENABLE_EXTRAS) - pybind11_add_module(ops ${PYBIND11_SOURCES}) + pybind11_add_module(ops ${PYBIND11_COMPILE_SOURCES}) else() - pybind11_add_module(ops NO_EXTRAS ${PYBIND11_SOURCES}) + pybind11_add_module(ops NO_EXTRAS ${PYBIND11_COMPILE_SOURCES}) + endif() + + if(WITH_TORCH AND CMAKE_GENERATOR MATCHES "Ninja") + set(INFINIOPS_BINDING_COMPILE_JOBS "2" CACHE STRING + "Maximum concurrent generated pybind11 binding compilations") + set_property(GLOBAL APPEND PROPERTY JOB_POOLS + binding_compile=${INFINIOPS_BINDING_COMPILE_JOBS}) + set_property(TARGET ops PROPERTY JOB_POOL_COMPILE binding_compile) + endif() + + if(WITH_METAX) + target_compile_options(ops PRIVATE "-x" "maca") + endif() + if(WITH_MOORE) + target_compile_options(ops PRIVATE "-x" "musa") 
endif() target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) @@ -392,4 +588,11 @@ if(GENERATE_PYTHON_BINDINGS) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" DESTINATION .) + + if(WITH_TORCH) + # Ship the per-op metadata alongside the bindings so the unified + # torch op test can discover what to exercise at runtime. + install(FILES "${PROJECT_SOURCE_DIR}/generated/torch_ops_metadata.json" + DESTINATION .) + endif() endif() From 0016a27129118d2a8202da119c66b6120f4e1e97 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 16 May 2026 09:42:01 +0800 Subject: [PATCH 4/4] test(torch): add generated backend coverage --- tests/conftest.py | 16 +- tests/test_generate_torch_ops.py | 33 ++ tests/test_torch_ops.py | 541 +++++++++++++++++++++++++++++++ 3 files changed, 589 insertions(+), 1 deletion(-) create mode 100644 tests/test_generate_torch_ops.py create mode 100644 tests/test_torch_ops.py diff --git a/tests/conftest.py b/tests/conftest.py index 86d01c249..b38c7f574 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,6 +75,7 @@ def _clear_operator_caches(): # `uint16`/`uint32`/`uint64`. for _bits in (16, 32, 64): _t = getattr(torch, f"uint{_bits}", None) + if _t is not None: _NPU_UNSUPPORTED_DTYPES.add(_t) @@ -100,6 +101,7 @@ def skip_op_without_platform_impl(request): `pytest_generate_tests` already prunes empty-impl pairs at collection time, making this check redundant (and wasteful) on those tests. """ + if not hasattr(request.node, "callspec"): return @@ -144,6 +146,7 @@ def _set_random_seed(seed): def _active_device_selectors_for_torch_device(config, torch_device): """Return platform or torch device names selected for a torch device type.""" + if not torch_device: return () @@ -159,11 +162,13 @@ def _active_device_selectors_for_torch_device(config, torch_device): # The pybind layer maps torch device names (e.g. 
"cuda") to the backend # compiled into the current wheel, avoiding probes of inactive CUDA siblings. + return (torch_device,) def _resolve_device(name): """Map a platform name (e.g., `ascend`) to a PyTorch device type (e.g., `npu`).""" + return _PLATFORM_TO_TORCH_DEVICE.get(name, name) @@ -301,7 +306,16 @@ def pytest_pyfunc_call(pyfuncitem): rtol = payload.rtol atol = payload.atol - assert torch.allclose(output, expected, rtol=rtol, atol=atol) + # `torch.allclose` rejects `bool` dtypes — use `torch.equal` for + # non-floating outputs (bool, int) so comparison ops work. Pass + # `equal_nan=True` so NaN-in-both-positions (common for special + # functions fed out-of-domain inputs) does not fail the test. + if output.dtype.is_floating_point: + assert torch.allclose( + output, expected, rtol=rtol, atol=atol, equal_nan=True + ) + else: + assert torch.equal(output, expected) return True diff --git a/tests/test_generate_torch_ops.py b/tests/test_generate_torch_ops.py new file mode 100644 index 000000000..456535567 --- /dev/null +++ b/tests/test_generate_torch_ops.py @@ -0,0 +1,33 @@ +import importlib.util +import pathlib +import sys + + +def _load_generator_module(): + path = ( + pathlib.Path(__file__).resolve().parents[1] + / "scripts" + / "generate_torch_ops.py" + ) + spec = importlib.util.spec_from_file_location("generate_torch_ops_under_test", path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module + + +def test_load_aten_yaml_uses_packaged_torchgen(monkeypatch): + module = _load_generator_module() + monkeypatch.setattr(module, "_load_packaged_aten_yaml", lambda: "packaged: true\n") + + assert module._load_aten_yaml("v9.9.9") == "packaged: true\n" + + +def test_public_op_name_normalizes_aten_internal_and_inplace_names(): + module = _load_generator_module() + + assert module._public_op_name("_softmax") == "aten_softmax" + assert 
module._public_op_name("add_") == "add_inplace" + assert module._public_op_name("_add_relu_") == "aten_add_relu_inplace" diff --git a/tests/test_torch_ops.py b/tests/test_torch_ops.py new file mode 100644 index 000000000..b8c54773d --- /dev/null +++ b/tests/test_torch_ops.py @@ -0,0 +1,541 @@ +"""Unified test for every operator emitted by `generate_torch_ops.py`. + +The generator writes `generated/torch_ops_metadata.json` listing every op +with full per-parameter info (`name`, `type`, `is_tensor`, `is_out`). +A single parametrized test reads that metadata, builds inputs from the +parameter list, calls the InfiniOps wrapper and the torch reference, and +compares each output tensor. Adding an op to `scripts/torch_ops.yaml` +extends coverage with no test changes. +""" + +import json +import pathlib +import re + +import infini.ops +import pytest +import torch + +from tests.utils import clone_strided, randn_strided + +# PyTorch backends are emitted at this slot — see `_PYTORCH_SLOT` in +# `scripts/generate_torch_ops.py`. +_PYTORCH_SLOT = 8 + +_INSTALLED_METADATA_PATH = ( + pathlib.Path(infini.ops.__file__).resolve().with_name("torch_ops_metadata.json") +) +_SOURCE_METADATA_PATH = ( + pathlib.Path(__file__).resolve().parent.parent + / "generated" + / "torch_ops_metadata.json" +) + +_METADATA_PATH = next( + ( + path + for path in (_INSTALLED_METADATA_PATH, _SOURCE_METADATA_PATH) + if path.exists() + ), + _SOURCE_METADATA_PATH, +) +_METADATA = ( + json.loads(_METADATA_PATH.read_text()) if _METADATA_PATH.exists() else {"ops": []} +) + +_SHAPES = ( + (13, 4), + (13, 4, 4), + (4, 4, 5632), +) + +_DTYPES = ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), +) + +# Op-specific input shapes for matrix ops (`mm` etc.) which cannot use +# `randn_strided(shape)` for both inputs. The tuple is one shape per +# tensor input, in YAML order. 
+_TENSOR_SHAPES = { + "mm": ((8, 16), (16, 12)), + "bmm": ((4, 8, 16), (4, 16, 12)), + "matmul": ((8, 16), (16, 12)), + "dot": ((16,), (16,)), + "vdot": ((16,), (16,)), + "mv": ((8, 16), (16,)), + "inner": ((8, 16), (8, 16)), + "outer": ((8,), (12,)), + "ger": ((8,), (12,)), + "kron": ((3, 4), (2, 3)), +} + +# Per-(op, param-name) values for non-tensor inputs. Lookup falls back +# to a type-based default if no entry exists. +_SCALAR_VALUES = { + ("clamp_min", "min"): -0.5, + ("clamp_max", "max"): 0.5, + ("leaky_relu", "negative_slope"): 0.01, + ("hardshrink", "lambd"): 0.5, + ("softshrink", "lambd"): 0.5, + ("mvlgamma", "p"): 2, + ("prod", "dim"): 0, + ("cumsum", "dim"): 0, + ("cumprod", "dim"): 0, + ("logcumsumexp", "dim"): 0, + ("cummax", "dim"): 0, + ("cummin", "dim"): 0, + ("softmax", "dim"): -1, + ("log_softmax", "dim"): -1, + ("threshold", "threshold"): 0.0, + ("threshold", "value"): 0.0, + ("hardtanh", "min_val"): -1.0, + ("hardtanh", "max_val"): 1.0, + ("softplus", "beta"): 1.0, + ("softplus", "threshold"): 20.0, + ("elu", "alpha"): 1.0, + ("elu", "scale"): 1.0, + ("elu", "input_scale"): 1.0, + ("sub", "alpha"): 1.0, + ("addcmul", "value"): 1.0, + ("addcdiv", "value"): 1.0, + # `str reduce` modes accepted by the corresponding ATen kernels. + ("index_reduce", "reduce"): "amax", + ("scatter_reduce", "reduce"): "amax", + ("scatter_reduce_two", "reduce"): "amax", + # `int dim` for ops where 0 is a safe choice for our test shapes. + ("kthvalue_values", "k"): 1, + ("kthvalue_values", "dim"): 0, + ("mode_values", "dim"): 0, +} + +_TYPE_DEFAULTS = {"int": 0, "SymInt": 0, "bool": False, "str": "none"} + +# Mirrors `kStringToDataType` in `src/data_type.h`. Any tensor passed to +# an InfiniOps op must have one of these dtypes; others (`bool`, complex, +# quantised types) abort the process inside `DataTypeFromString`. 
Some +# vendor torch forks lag behind upstream and lack `uint16` / `uint32` / +# `uint64` (added in PyTorch 2.3); resolve them lazily and keep the +# attributes that actually exist. +_SUPPORTED_DTYPE_NAMES = ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "bfloat16", + "float32", + "float64", +) +_SUPPORTED_DTYPES = frozenset( + getattr(torch, name) for name in _SUPPORTED_DTYPE_NAMES if hasattr(torch, name) +) + + +_LIST_SIZE_RE = re.compile(r"\[(\d+)\]") + + +def _is_inplace_aten_name(name): + """Return whether `name` is an ATen in-place operator name.""" + + return name.endswith("_") and not name.endswith("__") + + +def _list_default(aten_type): + """Default value for a required `int[N]` / `SymInt[N]` param. Most + such params name a `dim` or `kernel_size`; `[0]` works for `dim` and + causes `kernel_size`-style ops to fail their reference call cleanly, + which the test then skips.""" + size_match = _LIST_SIZE_RE.search(aten_type) + n = int(size_match.group(1)) if size_match else 1 + + return [0] * n + + +# Errors emitted by upstream PyTorch and vendor-forked variants for +# unsupported (op, dtype, device) combinations. We skip rather than fail +# on these — the gap is in PyTorch, not InfiniOps. +_VENDOR_SKIP_PATTERNS = ( + "not implemented for", # upstream PyTorch + "CNNL_STATUS_BAD_PARAM", # `torch_mlu` (Cambricon) + "MUDNN failed", # `torch_musa` (Moore) + "Could not run", # missing dispatcher entry on this backend + "don't support tensor dtype", # `torch_mlu` dtype check + "unknown format type", # `torch_npu` format descriptor gap + "result requires dtype", # output dtype mismatch (e.g. `float_power`) + # ATen kernels for some loss ops (`mse_loss`, `huber_loss`, …) use + # the `out` buffer as intermediate scratch and resize it before the + # final reduction. Our `from_blob` outputs are non-resizable, so + # the kernel aborts the call with this message. 
Skip these — the + # zero-copy wrapper can't drive that codepath. + "Trying to resize storage that is not resizable", +) + +# Random-sampling ops never match a fresh torch reference call — +# they consume RNG state and return different draws. Skip rather +# than try to align the two PRNG streams. +_RANDOM_OPS = frozenset( + { + "bernoulli", + "bernoulli_", + "multinomial", + "normal", + "rand", + "randn", + "randint", + "randperm", + "rrelu_with_noise", + } +) + +# Ops whose vendor kernel hangs indefinitely on at least one platform +# (`mode` on `torch_musa` for MUSA tensors). Skip until the vendor +# fixes the underlying kernel — letting the CI block on a hanging +# kernel costs ~30 min per platform run. +_VENDOR_HANG_OPS = frozenset( + { + "mode", + } +) + +# Ops whose vendor kernel crashes the Python process, so they must be skipped +# before calling into the InfiniOps/PyTorch slot. +_VENDOR_CRASH_OPS = frozenset( + { + ("npu", "mish"), + ("npu", "nuclear_norm"), + ("npu", "_linalg_svd"), + ("npu", "svd"), + } +) + +# Ops where the ATen `_out` schema and the Python reference (`torch.`, +# `torch.nn.functional.`) diverge in positional-argument ordering, so +# the harness's purely-positional reference call lands an InfiniOps +# argument on the wrong reference parameter. E.g. ATen +# `binary_cross_entropy_out(self, target, weight=None, reduction=Mean, out)` +# has `weight` between `target` and `reduction`; with `weight` hidden as +# `Tensor?`, our visible signature is `(self, target, reduction, out)`, +# but `torch.nn.functional.binary_cross_entropy(input, target, weight, +# reduction)` reads our `reduction:int` as `weight:Tensor` and crashes +# inside `weight.size()`. The InfiniOps wrapper itself is fine; only +# the harness's reference call is wrong. 
+_REFERENCE_SIGNATURE_MISMATCH_OPS = frozenset( + { + "binary_cross_entropy", + "binary_cross_entropy_backward", + } +) + +# Full reductions with low-precision inputs diverge between the functional +# (`torch.(x)`) and `_out` paths because of intermediate-precision +# choices we cannot align from outside ATen. +_LARGE_REDUCTION_OPS = frozenset( + {"sum", "mean", "nansum", "nanmean", "prod", "std", "var"} +) + +# Ops with input-domain `TORCH_CHECK` macros that fire as device-side +# `assert` on CUDA when our generic random fp32 inputs fall outside the +# expected range. The Python-side `RuntimeError` is catchable, but the +# CUDA context is left poisoned and every subsequent test errors at +# setup. Skip these on cuda; the CPU path raises a clean exception +# that the existing harness already handles. +_DEVICE_ASSERTING_OPS = frozenset( + { + "binary_cross_entropy", # requires inputs in [0, 1] + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + # cuDNN paths divide by `kernel_size`/`stride` and SIGFPE on the + # `[0, 0]` defaults our harness substitutes for required `int[N]` + # parameters. + "cudnn_convolution", + "slow_conv3d", + "slow_conv_transpose2d", + "slow_conv_transpose3d", + "thnn_conv2d", + "im2col", + "col2im", + "max_unpool2d", + "max_unpool3d", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "upsample_bicubic2d", + "upsample_bilinear2d", + "upsample_linear1d", + "upsample_nearest1d", + "upsample_nearest2d", + "upsample_nearest3d", + "upsample_trilinear3d", + "avg_pool2d", + "avg_pool3d", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + } +) + + +def _torch_func(op_name): + """Resolve the reference function across `torch`, `torch.special`, + and `torch.nn.functional`. 
`special_` falls through to + `torch.special.` with the prefix stripped.""" + + if _is_inplace_aten_name(op_name): + method_name = op_name + + def _call_inplace(input, *args): + return getattr(input, method_name)(*args) + + return _call_inplace + + candidates = [ + (torch, op_name), + (torch.special, op_name), + (torch.nn.functional, op_name), + ] + + if op_name.startswith("special_"): + candidates.append((torch.special, op_name.removeprefix("special_"))) + + for namespace, attr in candidates: + func = getattr(namespace, attr, None) + + if func is not None: + return func + + pytest.skip(f"no reference function for `{op_name}` in PyTorch") + + +def _pascal(snake_name): + return "".join(part.capitalize() for part in snake_name.split("_")) + + +def _skip_if_not_active(op_name, device): + op_class = getattr(infini.ops, _pascal(op_name), None) + + if op_class is None: + pytest.skip(f"`{op_name}` class not exposed on this build") + + if _PYTORCH_SLOT not in op_class.active_implementation_indices(device): + pytest.skip(f"`{op_name}` slot {_PYTORCH_SLOT} not active on `{device}`") + + +def _skip_low_precision_reduction(op_name, dtype, device): + if op_name in _LARGE_REDUCTION_OPS: + if dtype in (torch.float16, torch.bfloat16): + pytest.skip(f"`{op_name}` precision diverges on fp16/bf16") + + if device == "musa": + pytest.skip(f"`{op_name}` on `torch_musa` diverges from CPU reference") + + +def _build_input_value(op_name, param, shape, dtype, device, tensor_idx): + """Build the value passed to a non-out parameter.""" + + if param["is_tensor"]: + per_op = _TENSOR_SHAPES.get(op_name) + tshape = per_op[tensor_idx] if per_op is not None else shape + + return randn_strided(tshape, None, dtype=dtype, device=device) + + key = (op_name, param["name"]) + + if key in _SCALAR_VALUES: + return _SCALAR_VALUES[key] + + t = param["type"] + + if t.startswith(("int[", "SymInt[")) or t in {"int[]", "SymInt[]"}: + return _list_default(t) + + return _TYPE_DEFAULTS.get(t, 0.5) + + +def 
_call_infini(op_name, *args): + try: + getattr(infini.ops, op_name)(*args, implementation_index=_PYTORCH_SLOT) + except RuntimeError as exc: + if any(p in str(exc) for p in _VENDOR_SKIP_PATTERNS): + pytest.skip(f"`{op_name}` unsupported by torch on this device/dtype") + + raise + + +def _assert_close(actual, expected, rtol, atol): + if actual.dtype.is_floating_point: + assert torch.allclose(actual, expected, rtol=rtol, atol=atol, equal_nan=True) + else: + assert torch.equal(actual, expected) + + +def _testable_ops(): + """Filter the metadata down to ops the harness can drive. + + When multiple ATen overloads share the same `aten_name` they all + end up under one InfiniOps class (e.g., `std.dim` and + `std.correction` both map to `Std`), but each has a distinct ATen + `_out` signature. The reference call we synthesize from + `op_meta['params']` only exercises one signature; the secondary + overloads either rely on hidden defaults whose ATen interpretation + differs from the Python wrapper's (`std.correction(self, dim=None, + correction=None, ...)` defaults to a different correction than + `torch.std(self)`), or expose a positional shape that the Python + reference does not accept (e.g., `binary_cross_entropy_out`'s + `reduction:int` lands on the reference's `weight:Tensor?`). Keep + only the first overload of each `aten_name`.""" + seen = set() + keep = [] + + for op in _METADATA.get("ops", []): + if op["aten_name"] in seen: + continue + + seen.add(op["aten_name"]) + keep.append(op) + + return keep + + +def _op_meta_id(op_meta): + if not isinstance(op_meta, dict): + return "empty" + + # Multiple ATen overloads now share a single class name (`scatter` covers + # `scatter.src`, `scatter.value`, `scatter.reduce`, ...) — disambiguate + # parametrize ids by appending the visible parameter type signature so + # pytest does not collapse them into duplicate ids. 
+ + return op_meta["overload_name"] + + +@pytest.mark.parametrize("op_meta", _testable_ops(), ids=_op_meta_id) +@pytest.mark.parametrize("shape", _SHAPES, ids=lambda s: "x".join(map(str, s))) +@pytest.mark.parametrize(("dtype", "rtol", "atol"), _DTYPES) +def test_op(op_meta, shape, dtype, device, rtol, atol): + op_name = op_meta["name"] + aten_name = op_meta.get("aten_name", op_name) + is_inplace = _is_inplace_aten_name(aten_name) + _skip_if_not_active(op_name, device) + _skip_low_precision_reduction(aten_name, dtype, device) + + if aten_name in _RANDOM_OPS: + pytest.skip(f"`{aten_name}` is non-deterministic (independent draws diverge)") + + if aten_name in _REFERENCE_SIGNATURE_MISMATCH_OPS: + pytest.skip( + f"`{aten_name}`'s ATen `_out` and Python reference signatures " + "have different positional ordering" + ) + + if aten_name in _VENDOR_HANG_OPS: + pytest.skip(f"`{aten_name}` hangs on at least one vendor kernel") + + if (device, aten_name) in _VENDOR_CRASH_OPS: + pytest.skip(f"`{aten_name}` crashes on `{device}` vendor kernel") + + if device == "cuda" and aten_name in _DEVICE_ASSERTING_OPS: + pytest.skip( + f"`{aten_name}` triggers a CUDA device-side assert on random inputs" + ) + + in_params = ( + op_meta["params"] + if is_inplace + else [p for p in op_meta["params"] if not p["is_out"]] + ) + out_params = [p for p in op_meta["params"] if p["is_out"]] + + # Build inputs in YAML order. + inputs = [] + tensor_idx = 0 + + for p in in_params: + inputs.append( + _build_input_value(aten_name, p, shape, dtype, device, tensor_idx) + ) + + if p["is_tensor"]: + tensor_idx += 1 + + # Run the reference to discover output shape(s)/dtype(s). + # An op may reject our generic `randn(shape)` input with any of these + # exception types — the gap is in our test harness's input synthesis, + # not in the InfiniOps wrapper. 
+ ref_inputs = [ + clone_strided(x) if isinstance(x, torch.Tensor) else x for x in inputs + ] + + try: + ref = _torch_func(aten_name)(*ref_inputs) + except ( + RuntimeError, + TypeError, + ValueError, + IndexError, + NotImplementedError, + ) as exc: + pytest.skip(f"`torch.{aten_name}` rejects these inputs: {exc}") + + ref_outs = ref if isinstance(ref, tuple) else (ref,) + + if is_inplace: + ref_outs = (ref_inputs[0],) + + if len(ref_outs) != len(out_params): + # The Python-facing function (e.g. `F.adaptive_max_pool2d`) often + # exposes a subset of the ATen `_out` schema's outputs (returning + # only `out`, hiding `indices` behind a `return_indices=True` + # kwarg). Without a per-op map of how to coax the full tuple + # out, skip — the InfiniOps wrapper itself is fine. + pytest.skip( + f"`{aten_name}` reference produced {len(ref_outs)} output(s); " + f"schema declares {len(out_params)}" + ) + + # InfiniOps `DataType` enumerates only int{8,16,32,64}, uint{8,16,32,64}, + # float{16,32,64}, and bfloat16. Tensors with any other torch dtype + # (`bool`, `complex64`, `complex128`, …) abort on `DataTypeFromString`, + # so skip the test rather than crash the process. + tensors = [*ref_outs, *(x for x in inputs if isinstance(x, torch.Tensor))] + unsupported = next( + (t.dtype for t in tensors if t.dtype not in _SUPPORTED_DTYPES), None + ) + + if unsupported is not None: + pytest.skip( + f"`{op_name}` uses dtype {unsupported} — not in InfiniOps `DataType`" + ) + + # On CUDA, `torch.empty_like` of a 0-element tensor gives a tensor + # whose `data_ptr()` is unregistered with the device; passing it + # through to the wrapper trips "pointer resides on host memory". 
+ if any(t.numel() == 0 for t in ref_outs): + pytest.skip( + f"`{op_name}` produced 0-element output (unregistered data_ptr on cuda)" + ) + + if is_inplace: + _call_infini(op_name, *inputs) + _assert_close(inputs[0], ref_outs[0], rtol, atol) + + return + + outs = [torch.empty_like(t) for t in ref_outs] + _call_infini(op_name, *inputs, *outs) + + for actual, expected in zip(outs, ref_outs): + _assert_close(actual, expected, rtol, atol)