From d4b7d60af8a864de33ff7a63a9fdfe2d28e4200f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 Jan 2026 15:40:28 -0800 Subject: [PATCH 01/31] Replace op_schema with op_signature Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 27 +++++++++++++------------- onnxscript/_internal/converter.py | 6 +++--- onnxscript/_internal/evaluator.py | 32 ++++++++++++------------------- 3 files changed, 29 insertions(+), 36 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index bc3e16f79e..59732d0c08 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -8,9 +8,9 @@ import numpy as np import onnx import onnx.helper # noqa: TID251 -from onnx.defs import OpSchema from onnxscript import ir, tensor +from onnxscript.ir import _schemas if TYPE_CHECKING: from onnxscript._internal import converter @@ -126,7 +126,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None): def cast_inputs( get_type_info: Callable[[Any], Any], cast: Callable[[Any, Any], Any], - op_schema: OpSchema | None, + op_signature: _schemas.OpSignature | None, args, ) -> tuple[Any, ...]: """Uses schema specification to support a limited form of auto-casting. @@ -140,12 +140,15 @@ def cast_inputs( This is used by the converter in a static-mode, as well as by the eager-mode execution in a dynamic-mode. """ - if op_schema is None: + if op_signature is None: # Either an error or a custom op. # No checks/casts in this case. return tuple(cast(x, None) for x in args) - expected_inputs = op_schema.inputs + # Filter to get only input parameters (not AttributeParameters) + expected_inputs = [ + param for param in op_signature.params if isinstance(param, _schemas.Parameter) + ] # We make two passes. In the first pass, we identify known type-bindings for # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}. # In the second pass, we use these bindings to cast scalar-values to @@ -156,17 +159,15 @@ def cast_inputs( for i, x in enumerate(args): if i < len(expected_inputs): expected = expected_inputs[i] - elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic: + elif expected_inputs[-1].variadic: expected = expected_inputs[-1] - if not expected.is_homogeneous: - args_typevars.append((x, None)) - continue + # TODO(justinchuby): Handle is_homogeneous params else: raise ValueError( f"Number of actual parameters {len(args)} " f"exceeds number of formal parameters {len(expected_inputs)}." ) - typevar = expected.type_str + typevar = expected.type_constraint.name if "(" not in typevar: # typevar is an identifier, like "T" typeinfo = get_type_info(x) @@ -177,18 +178,18 @@ def cast_inputs( return tuple(cast_args) -def dynamic_cast_inputs(op_schema: OpSchema, args): +def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args): """Used for autocast during eager-mode execution.""" def get_type_info(x): return x.dtype if isinstance(x, tensor.Tensor) else None - return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args) + return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args) def static_cast_inputs( converter_: converter.Converter, - op_schema: Optional[OpSchema], + op_signature: Optional[_schemas.OpSignature], args: Sequence[Optional[ir.Value]], ) -> tuple[str, ...]: """Used for autocast during script-translation. @@ -212,4 +213,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]: return converter_.emit1([x_cast], "CastLike", [x, y]) return x - return cast_inputs(get_type_info, cast_like, op_schema, args) + return cast_inputs(get_type_info, cast_like, op_signature, args) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index dd902ac7ab..6ab228ef4d 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -887,7 +887,7 @@ def _translate_call_expr( else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.op_schema, args) + args = autocast.static_cast_inputs(self, callee.op_signature, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. @@ -896,8 +896,8 @@ def _translate_call_expr( return callee, args, attrs def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]: - schema = op.op_schema - return autocast.static_cast_inputs(self, schema, (left, right)) + op_signature = op.op_signature + return autocast.static_cast_inputs(self, op_signature, (left, right)) def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 1415733397..5eefe3d963 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -189,38 +189,35 @@ def eval( attributes: The ONNX attributes to the op. """ attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self.adapt_attributes(schema, attributes) - inputs = self.adapt_inputs(schema, inputs) + attributes, closure = self._adapt_attributes(attributes) + inputs = self._adapt_inputs(schema, inputs) outputs = self._eval(schema, inputs, attributes, closure) - return self.adapt_outputs(schema, outputs) + return self._adapt_outputs(outputs) - def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): + def _adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): """Transform inputs to the expected format for the evaluator. Enables some syntactic sugar, such as the use of Python scalars, in a manner consistent with the translator. See autocast.py for details. """ - return autocast.dynamic_cast_inputs(schema, inputs) + op_signature = _schemas.OpSignature.from_op_schema(schema) + return autocast.dynamic_cast_inputs(op_signature, inputs) - def adapt_attributes( - self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue] + def _adapt_attributes( + self, attributes: Mapping[str, ExtendedModeValue] ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]: """Transform attributes to the expected format for the evaluator. Returns: A closure that can be used to evaluate graph-valued attributes. """ - use_graph_attribute = self.use_graph_attribute(schema) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): if isinstance(v, values.OnnxClosure): - if use_graph_attribute: - adapted_attributes[k] = v.function_ir.to_graph_proto() - for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value.name] = v.frame.f_locals[pyvar] - else: - adapted_attributes[k] = v.function + adapted_attributes[k] = v.function_ir.to_graph_proto() + for pyvar, onnxvar in v.function_ir.outer_scope_variables: + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] elif callable(v): raise TypeError( f"Error: function-valued attribute {v.__name__} has no graph_proto" @@ -230,18 +227,13 @@ def adapt_attributes( adapted_attributes[k] = v return adapted_attributes, closure - def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]): + def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """Adapt evaluator's output to convention used in onnxscript. Onnxscript uses a tuple/sequence only when number of outputs > 1. """ - del schema # unused return outputs[0] if len(outputs) == 1 else outputs - def use_graph_attribute(self, schema: onnx.defs.OpSchema): - del schema # unused - return True - @abc.abstractmethod def _eval( self, From 78a25d816aeb835ad511e5f9913902940e733140 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 Jan 2026 20:47:11 -0800 Subject: [PATCH 02/31] wip Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 5eefe3d963..edba055236 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -188,29 +188,32 @@ def eval( inputs: The ONNX inputs to the op. attributes: The ONNX attributes to the op. """ + op_signature = _schemas.OpSignature.from_op_schema(schema) attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self._adapt_attributes(attributes) - inputs = self._adapt_inputs(schema, inputs) + attributes, closure = self._adapt_attributes(op_signature, attributes) + inputs = self._adapt_inputs(op_signature, inputs) outputs = self._eval(schema, inputs, attributes, closure) return self._adapt_outputs(outputs) - def _adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): + def _adapt_inputs( + self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue] + ): """Transform inputs to the expected format for the evaluator. Enables some syntactic sugar, such as the use of Python scalars, in a manner consistent with the translator. See autocast.py for details. """ - op_signature = _schemas.OpSignature.from_op_schema(schema) return autocast.dynamic_cast_inputs(op_signature, inputs) def _adapt_attributes( - self, attributes: Mapping[str, ExtendedModeValue] + self, op_signature, attributes: Mapping[str, ExtendedModeValue] ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]: """Transform attributes to the expected format for the evaluator. Returns: A closure that can be used to evaluate graph-valued attributes. """ + use_graph_attribute = self.use_graph_attribute(op_singature) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): @@ -234,6 +237,12 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """ return outputs[0] if len(outputs) == 1 else outputs + + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + del schema # unused + return True + + @abc.abstractmethod def _eval( self, From f4e47bbf2120abe5b3c69c7894318a929a831345 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:03:00 -0800 Subject: [PATCH 03/31] Clean up _to_model_proto Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 49 ++++---- onnxscript/_internal/irbuilder.py | 8 +- onnxscript/_internal/values.py | 185 +++--------------------------- onnxscript/tensor.py | 2 + 4 files changed, 53 insertions(+), 191 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index edba055236..9b5cb2eabe 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -138,6 +138,21 @@ def eval( inputs: The ONNX inputs to the op. attributes: The ONNX attributes to the op. """ + # Deprecated. Implement eval_op instead + + def eval_op( + self, + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], + ): + """Evaluates an Op. + + Args: + op: The Op to evaluate. + args: The positional arguments to the op. + kwargs: The keyword arguments to the op. + """ def eval_function( self, @@ -175,26 +190,6 @@ def __init__(self, ignore_unknown_function_kwargs: bool = False): """ self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs - def eval( - self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], - ): - """Evaluates an ONNX op. - - Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. - """ - op_signature = _schemas.OpSignature.from_op_schema(schema) - attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self._adapt_attributes(op_signature, attributes) - inputs = self._adapt_inputs(op_signature, inputs) - outputs = self._eval(schema, inputs, attributes, closure) - return self._adapt_outputs(outputs) - def _adapt_inputs( self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue] ): @@ -260,6 +255,20 @@ def _eval( closure: The closure to use when evaluating graph-valued attributes. """ + def eval_op( + self, + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], + ): + op_signature = op.op_signature + assert op_signature is not None, f"Op {op.name} has no signature." + attributes = _unwrap_tensors_in_kwargs(kwargs) + attributes, closure = self._adapt_attributes(op_signature, attributes) + inputs = self._adapt_inputs(op_signature, args) + outputs = self._eval(schema, inputs, attributes, closure) + return self._adapt_outputs(outputs) + def eval_function( self, function: values.OnnxFunction, diff --git a/onnxscript/_internal/irbuilder.py b/onnxscript/_internal/irbuilder.py index e5fa80622e..f287b6b1ab 100644 --- a/onnxscript/_internal/irbuilder.py +++ b/onnxscript/_internal/irbuilder.py @@ -77,7 +77,7 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def get_called_functions(self) -> dict[str, onnx.FunctionProto]: + def get_called_functions(self) -> dict[str, ir.Function]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): @@ -94,12 +94,12 @@ def add(f: values.OnnxFunction): visit(self) - return {name: f.to_function_proto() for name, f in called_functions.items()} + return {name: f.function_ir for name, f in called_functions.items()} def to_graph_proto(self) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`.""" - return ir.to_proto(self.graph) + return ir.serde.serialize_graph(self.graph) def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" - return ir.to_proto(self) + return ir.serde.serialize_function(self) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 2f22e1eefa..d87e661f46 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -4,6 +4,7 @@ # ruff: noqa: TID251 from __future__ import annotations +from collections.abc import Collection import dataclasses import functools @@ -168,9 +169,6 @@ def name(self) -> str: ... @property def opset(self) -> Opset: ... - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... - @property def op_signature(self) -> Optional[_schemas.OpSignature]: ... @@ -203,15 +201,16 @@ def __init__( ) def __call__(self, *args, **kwargs): - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel - schema = self.op_schema - if schema is None: - raise RuntimeError( - f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." - ) - return evaluator.default().eval(schema, args, kwargs) + default_evaluator = evaluator.default() + if hasattr(default_evaluator, "eval"): + # Interface prior to onnxscript 0.6, used by PyTorch 2.10 and older + if self.op_schema is None: + raise ValueError(f"OpSchema not found for op '{self.name}'.") + return default_evaluator.eval(self.op_schema, args, kwargs) + # Use the new interface + return evaluator.default().eval_op(self, args, kwargs) @property def name(self) -> str: @@ -225,10 +224,6 @@ def opset(self) -> Opset: def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema - def has_schema(self) -> bool: - """Returns True if this op has an OpSchema.""" - return self.op_schema is not None - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" @@ -261,99 +256,6 @@ class OnnxClosure: function: Any -@dataclasses.dataclass -class TypeConstraint: - """Represents a type constraint for an ONNX op. - - Attributes: - name: The name of the type constraint. - allowed_types: The allowed types for the type constraint. - """ - - name: str - allowed_types: list[str] - description: str = "" - - def as_tuple(self) -> tuple[str, list[str], str]: - """Returns the type constraint as a tuple.""" - return (self.name, self.allowed_types, self.description) - - -def _op_schema_from_function_ir( - function_ir: irbuilder.IRFunction, opset: Opset -) -> onnx.defs.OpSchema: - """Construct an ONNX OpSchema from an IRFunction.""" - - # Find all distinct types in the inputs and outputs - distinct_types = {_typeinfo(arg) for arg in function_ir.inputs}.union( - {_typeinfo(arg) for arg in function_ir.outputs} - ) - # Create a mapping from type to a unique name - type_to_constraint = {} - for i, type_ in enumerate(distinct_types): - name = f"T{i}" - type_to_constraint[type_] = TypeConstraint( - name=type_annotation.get_type_constraint_name(type_) or name, - allowed_types=type_annotation.pytype_to_type_strings(type_), - ) - - formal_inputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.inputs - ] - formal_outputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.outputs - ] - return onnx.defs.OpSchema( - function_ir.name, - opset.domain, - since_version=opset.version, - doc=function_ir.doc_string or "", - inputs=formal_inputs, - outputs=formal_outputs, - type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], - attributes=[ - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), # type: ignore[call-arg] - ) - for attr in function_ir.attrs - if attr.value is None - ], - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - default_value=ir.to_proto(attr), - ) - for attr in function_ir.attrs - if attr.value is not None - ], - ], - ) - - class OnnxFunction(Op, Generic[_P, _R]): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -399,27 +301,6 @@ def __init__( # Experimental fields self.traceable = False - @property - @deprecation.deprecated( - since="0.1", - removed_in="the future", - instructions="use '.name' instead", - ) - def opname(self) -> str: - # NOTE: This is a temporary alias for backward compatibility with PyTorch 2.0. - # TODO: Remove this in onnxscript 0.3. - return self.name - - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Construct an OpSchema from function_ir.""" - if self._op_schema is not None: - return self._op_schema - - self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" @@ -438,28 +319,8 @@ def op_signature(self) -> Optional[_schemas.OpSignature]: def op_signature(self, value: _schemas.OpSignature): self._signature = value - def __getitem__(self, instance): - """Returns a lambda to evaluate function using given evaluator instance. - - Usage: - script_fun(X) executes the function using the default evaluator instance. - script_fun[instance](X) executes the function using the given evaluator instance. - """ - - def fun(*args, **kwargs): - # FIXME(after #225): Move import to the top of the file. - from onnxscript._internal import ( # pylint: disable=import-outside-toplevel - evaluator, - ) - - with evaluator.default_as(instance): - return self.__call__(*args, **kwargs) - - return fun - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Implements an eager-mode execution of an onnxscript function.""" - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] @@ -490,7 +351,7 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions=None, + functions: Collection[ir.Function] | None = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -522,27 +383,15 @@ def _to_model_proto( if functions is None: sub_functions = self.function_ir.get_called_functions() functions = sub_functions.values() - else: - - def to_proto(f): - if isinstance(f, onnx.FunctionProto): - return f - if isinstance(f, OnnxFunction): - return f.to_function_proto() - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") - - functions = [to_proto(f) for f in functions] # Determine opset imports opsets = self.function_ir.graph.opset_imports - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 - # TODO(rama): Handle conflicts with appropriate error/warning message. - for opset in proto.opset_import: - if opset.domain not in opsets: - opsets[opset.domain] = opset.version + for func in functions: + if func.domain not in opsets: + opsets[func.domain] = 1 + + # No need to collect opsets from functions if "" not in opsets: # No operator is using the standard opset. @@ -559,8 +408,10 @@ def to_proto(f): # Create the model model = ir.Model(self.function_ir.graph, ir_version=ir_version) + for func in functions: + model.functions[func.identifier()] = func + model_proto = ir.to_proto(model) - model_proto.functions.extend(functions) # Set additional type information if provided graph = model_proto.graph diff --git a/onnxscript/tensor.py b/onnxscript/tensor.py index f1d781b808..6ad8f6bf12 100644 --- a/onnxscript/tensor.py +++ b/onnxscript/tensor.py @@ -16,6 +16,8 @@ class Tensor: Serves to define overloaded ops with an ONNX/ONNXScript semantics. """ + # TODO(justinchuby): Remove the tensor class and use ir.Value instead + def __init__(self, nparray: Optional[np.ndarray], opset=None): if nparray is not None and not isinstance(nparray, np.ndarray): raise TypeError( From be172c16d58a05bb76d09246b681e3f270e2ee54 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:22:59 -0800 Subject: [PATCH 04/31] wip Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 33 ++++++++++++++++++------------- onnxscript/_internal/values.py | 8 ++++---- onnxscript/ir/_schemas.py | 5 ++++- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 9b5cb2eabe..fc12ce725c 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -208,14 +208,17 @@ def _adapt_attributes( Returns: A closure that can be used to evaluate graph-valued attributes. """ - use_graph_attribute = self.use_graph_attribute(op_singature) + use_graph_attribute = self.use_graph_attribute(op_signature) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): if isinstance(v, values.OnnxClosure): - adapted_attributes[k] = v.function_ir.to_graph_proto() - for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value.name] = v.frame.f_locals[pyvar] + if use_graph_attribute: + adapted_attributes[k] = v.function_ir.to_graph_proto() + for pyvar, onnxvar in v.function_ir.outer_scope_variables: + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] + else: + adapted_attributes[k] = v.function elif callable(v): raise TypeError( f"Error: function-valued attribute {v.__name__} has no graph_proto" @@ -234,7 +237,7 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: - del schema # unused + del op_signature # unused return True @@ -266,7 +269,7 @@ def eval_op( attributes = _unwrap_tensors_in_kwargs(kwargs) attributes, closure = self._adapt_attributes(op_signature, attributes) inputs = self._adapt_inputs(op_signature, args) - outputs = self._eval(schema, inputs, attributes, closure) + outputs = self._eval(op.op_schema, inputs, attributes, closure) return self._adapt_outputs(outputs) def eval_function( @@ -285,6 +288,8 @@ def eval_function( kwargs: The keyword arguments to the function. """ op_signature = function.op_signature + if op_signature is None: + raise RuntimeError(f"Function {function.name} has no signature.") # Split happens in the evaluator instead of the OnnxFunction __call__ method # so that evaluators can control behaviors like whether to fill in default values for attributes. tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature( @@ -514,7 +519,7 @@ def _call_ort( return [_numpy_to_onnxscript_value(x) for x in result] -def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]: +def _op_identifier(schema) -> tuple[str, str, int]: return schema.name, schema.domain, schema.since_version @@ -562,13 +567,13 @@ def __init__(self) -> None: super().__init__() self._python_ops: dict[tuple[str, str, int], Any] = {} - def use_graph_attribute(self, schema: onnx.defs.OpSchema) -> bool: - return _schema_id(schema) not in self._python_ops + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + return _op_identifier(op_signature) not in self._python_ops def _eval(self, schema, inputs, attributes, closure): - schemaid = _schema_id(schema) - if schemaid in self._python_ops: - return self._python_ops[schemaid](inputs, attributes) + identifier = _op_identifier(schema) + if identifier in self._python_ops: + return self._python_ops[identifier](inputs, attributes) else: return super()._eval(schema, inputs, attributes, closure) @@ -576,8 +581,8 @@ def register(self, opset: values.Opset) -> Callable[[_T], _T]: assert opset is not None def decorator(function: _T) -> _T: - schema = opset[function.__name__] - self._python_ops[_schema_id(schema)] = function + op_signature = opset[function.__name__].op_signature + self._python_ops[_op_identifier(op_signature)] = function return function return decorator diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index d87e661f46..70f35d1c4f 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -4,7 +4,6 @@ # ruff: noqa: TID251 from __future__ import annotations -from collections.abc import Collection import dataclasses import functools @@ -12,7 +11,8 @@ import logging import types import typing -from typing import ( # type: ignore[attr-defined] +from collections.abc import Collection +from typing import ( Any, Callable, ClassVar, @@ -28,7 +28,7 @@ import onnx_ir as ir from typing_extensions import ParamSpec -from onnxscript._internal import ast_utils, deprecation, irbuilder, sourceinfo, type_annotation +from onnxscript._internal import ast_utils, irbuilder, sourceinfo from onnxscript._internal import converter as converter_module from onnxscript.ir import _schemas from onnxscript.onnx_types import ONNXType @@ -123,7 +123,7 @@ def __contains__(self, opname): def __str__(self) -> str: return self.domain - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Op: try: schema = onnx.defs.get_schema(attr, self.version, self.domain) return Op(self, attr, schema) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..66f6875eb2 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -339,6 +339,7 @@ class OpSignature: params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( init=False, repr=False ) + since_version: int = 1 def __post_init__(self): self.params_map = {param.name: param for param in self.params} @@ -415,11 +416,12 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: overload="", params=params, outputs=outputs, + since_version=op_schema.since_version, ) @classmethod def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "" + cls, func, domain: str, name: str | None = None, overload: str = "", since_version: int = 1 ) -> OpSignature: """Produce an OpSignature from a function using type annotation.""" @@ -545,4 +547,5 @@ def from_function( overload=overload, params=params, outputs=outputs, + since_version=since_version, ) From 38ecac75bd1929825c349cfa065bbc5d1582ff1c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:28:00 -0800 Subject: [PATCH 05/31] update Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 2 -- onnxscript/_internal/values.py | 23 +++++++++-------------- onnxscript/ir/_schemas.py | 7 ++++++- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index fc12ce725c..d7e9d5b775 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -235,12 +235,10 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """ return outputs[0] if len(outputs) == 1 else outputs - def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: del op_signature # unused return True - @abc.abstractmethod def _eval( self, diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 70f35d1c4f..015fa4875f 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -108,7 +108,8 @@ def __repr__(self): def __getitem__(self, opname): try: - return onnx.defs.get_schema(opname, self.version, self.domain) + schema = onnx.defs.get_schema(opname, self.version, self.domain) + return Op(self, opname, schema) except Exception: # pylint: disable=broad-except # TODO: more specific exception return None @@ -189,7 +190,13 @@ def __init__( ) -> None: self._opset = opset self._name = name - self._op_schema = op_schema or opset[name] + self._op_schema: onnx.defs.OpSchema | None + if op_schema is not None: + self._op_schema = op_schema + elif (op := opset[name]) is not None: + self._op_schema = op.op_schema + else: + self._op_schema = None self._signature: Optional[_schemas.OpSignature] = None if self._op_schema is None: @@ -484,18 +491,6 @@ def function_ir(self) -> irbuilder.IRFunction: return converter.translate_function_signature(func_ast) - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Return the OpSchema.""" - - if self._op_schema is not None: - return self._op_schema - - # FIXME(justinchuby): outputs are empty. Need to fix. - self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 66f6875eb2..ea8affc37d 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -421,7 +421,12 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: @classmethod def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "", since_version: int = 1 + cls, + func, + domain: str, + name: str | None = None, + overload: str = "", + since_version: int = 1, ) -> OpSignature: """Produce an OpSignature from a function using type annotation.""" From 26151b25916528986b9b7d55955c2069d0cfda45 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:38:52 -0800 Subject: [PATCH 06/31] Fixes Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 19 ++++--------------- onnxscript/_internal/evaluator_test.py | 6 ++++-- onnxscript/_internal/values.py | 26 +++++++------------------- onnxscript/ir/_schemas.py | 2 +- 4 files changed, 16 insertions(+), 37 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index d7e9d5b775..1f29f15d68 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -123,22 +123,11 @@ def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]: @runtime_checkable class Evaluator(Protocol): - """Protocol for evaluating ONNX ops.""" + """Protocol for evaluating ONNX ops. - def eval( - self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], - ): - """Evaluates an ONNX op. - - Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. - """ - # Deprecated. Implement eval_op instead + NOTE: The ``eval`` method was deprecated and removed. Implement ``eval_op`` + and ``eval_function`` instead. + """ def eval_op( self, diff --git a/onnxscript/_internal/evaluator_test.py b/onnxscript/_internal/evaluator_test.py index c696ddf9b4..4949c04675 100644 --- a/onnxscript/_internal/evaluator_test.py +++ b/onnxscript/_internal/evaluator_test.py @@ -31,11 +31,13 @@ def square(y: FLOAT["N"]) -> FLOAT["N"]: # noqa: F821 np.testing.assert_equal(output, expected) # Test using ort-mixed-evaluator - output = seq_map[evaluator.ort_mixed_evaluator](x) + with evaluator.default_as(evaluator.ort_mixed_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) # Test using ort-evaluator - output = seq_map[evaluator.ort_evaluator](x) + with evaluator.default_as(evaluator.ort_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 015fa4875f..e964078a05 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -300,7 +300,9 @@ def __init__( self.function_ir = irfun self.source = source self.kwargs = kwargs - self._op_schema: Optional[onnx.defs.OpSchema] = None + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, pyfun) @@ -311,15 +313,6 @@ def __init__( @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.function, domain=self.function_ir.domain, name=self.name - ) return self._signature @op_signature.setter @@ -400,6 +393,7 @@ def _to_model_proto( # No need to collect opsets from functions + # FIXME: Collect used opsets from the function nodes if "" not in opsets: # No operator is using the standard opset. # Use the specified version if provided or the default value. @@ -462,6 +456,9 @@ class TracedOnnxFunction(Op): def __init__(self, opset: Opset, func: Callable): super().__init__(opset, func.__name__) self.func = func + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, func) @@ -494,15 +491,6 @@ def function_ir(self) -> irbuilder.IRFunction: @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.func, domain="_traced", name=self.name - ) return self._signature @op_signature.setter diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index ea8affc37d..6d3a20bbed 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -441,7 +441,7 @@ def from_function( for param in py_signature.parameters.values(): if param.name not in type_hints: - logger.warning( + logger.debug( "Missing annotation for parameter '%s' from %s. Treating as an Input.", param.name, py_signature, From 0fe2f6de68e6996694bf82f0fc67309d6a86f3e0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:43:44 -0800 Subject: [PATCH 07/31] Fix converter Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 6ab228ef4d..1c1c0963ad 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -12,7 +12,6 @@ Union, ) -import onnx import onnx_ir as ir import onnxscript @@ -29,6 +28,7 @@ from onnxscript._internal import ( type_annotation as ta, ) +from onnxscript.ir import _schemas logger = logging.getLogger("onnxscript") @@ -518,7 +518,7 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: onnx.defs.OpSchema.Attribute | None = None, + attr_meta: _schemas.AttributeParameter | None = None, ) -> ir.Attr | None: """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. The following cases are supported: @@ -880,14 +880,11 @@ def _translate_call_expr( op_signature, node.args, kwargs, fill_defaults=False ) args = [self._translate_opt_expr(x) for x in args] - attrs = [ - self._translate_attr(x, y, callee.op_schema.attributes[x]) - for x, y in attrs.items() - ] + attrs = [self._translate_attr(x, y, op_signature.get(x)) for x, y in attrs.items()] else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.op_signature, args) + args = autocast.static_cast_inputs(self, op_signature, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. From 0e067b6932fa34a48e4095d4e734b704ce9393ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:49:11 -0800 Subject: [PATCH 08/31] homogeneous Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 4 +++- onnxscript/ir/_schemas.py | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 59732d0c08..99d0e82f5b 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -161,7 +161,9 @@ def cast_inputs( expected = expected_inputs[i] elif expected_inputs[-1].variadic: expected = expected_inputs[-1] - # TODO(justinchuby): Handle is_homogeneous params + if not expected.homogeneous: + args_typevars.append((x, None)) + continue else: raise ValueError( f"Number of actual parameters {len(args)} " diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 6d3a20bbed..1f14634834 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -106,8 +106,10 @@ class Parameter: type_constraint: TypeConstraintParam required: bool variadic: bool + homogeneous: bool = True + min_arity: int = 1 + # TODO: Add differentiation_category default: Any = _EMPTY_DEFAULT - # TODO: Add other properties too def __str__(self) -> str: type_str = self.type_constraint.name @@ -188,6 +190,8 @@ def _convert_formal_parameter( type_constraint=type_constraint, required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + homogeneous=param.is_homogeneous, + min_arity=param.min_arity, ) @@ -455,6 +459,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -505,6 +510,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -542,6 +548,7 @@ def from_function( type_constraint=type_constraint, required=True, variadic=False, + homogeneous=True, default=_EMPTY_DEFAULT, ) ) From 1109f261196a21ba93a0b36fe4e280bb82153853 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 12:07:11 -0800 Subject: [PATCH 09/31] fix call functions Signed-off-by: Justin Chu --- onnxscript/_internal/irbuilder.py | 4 ++-- onnxscript/_internal/values.py | 26 ++++++++++++-------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/onnxscript/_internal/irbuilder.py b/onnxscript/_internal/irbuilder.py index f287b6b1ab..1ae3c7bdb1 100644 --- a/onnxscript/_internal/irbuilder.py +++ b/onnxscript/_internal/irbuilder.py @@ -77,7 +77,7 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def get_called_functions(self) -> dict[str, ir.Function]: + def get_called_functions(self) -> dict[str, values.OnnxFunction]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): @@ -94,7 +94,7 @@ def add(f: values.OnnxFunction): visit(self) - return {name: f.function_ir for name, f in called_functions.items()} + return called_functions def to_graph_proto(self) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`.""" diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index e964078a05..08dc6d32fc 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -11,7 +11,6 @@ import logging import types import typing -from collections.abc import Collection from typing import ( Any, Callable, @@ -351,7 +350,6 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions: Collection[ir.Function] | None = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -380,24 +378,24 @@ def _to_model_proto( An instance of :class:`onnx.ModelProto`. """ # Identify functions to include in the model - if functions is None: - sub_functions = self.function_ir.get_called_functions() - functions = sub_functions.values() + sub_functions = self.function_ir.get_called_functions() + functions = sub_functions.values() # Determine opset imports - opsets = self.function_ir.graph.opset_imports + opset_imports = self.function_ir.graph.opset_imports for func in functions: - if func.domain not in opsets: - opsets[func.domain] = 1 + domain = func.opset.domain + if domain is not None and domain not in opset_imports: + opset_imports[domain] = func.opset.version - # No need to collect opsets from functions + if "" not in opset_imports and "" in func.function_ir.opset_imports: + opset_imports[""] = func.function_ir.opset_imports[""] - # FIXME: Collect used opsets from the function nodes - if "" not in opsets: + if "" not in opset_imports: # No operator is using the standard opset. # Use the specified version if provided or the default value. - opsets[""] = ( + opset_imports[""] = ( opset_version if opset_version is not None else onnx.defs.onnx_opset_version() ) @@ -405,12 +403,12 @@ def _to_model_proto( if "ir_version" in kwargs: ir_version = kwargs.pop("ir_version") else: - ir_version = select_ir_version(opsets[""]) + ir_version = select_ir_version(opset_imports[""]) # Create the model model = ir.Model(self.function_ir.graph, ir_version=ir_version) for func in functions: - model.functions[func.identifier()] = func + model.functions[func.function_ir.identifier()] = func.function_ir model_proto = ir.to_proto(model) From 8bd5f52898318091561cce315116ac898e3c4442 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 13:38:07 -0800 Subject: [PATCH 10/31] copilot Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 10 ++++++++-- onnxscript/_internal/values.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 1f29f15d68..d6075d29be 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -506,8 +506,14 @@ def _call_ort( return [_numpy_to_onnxscript_value(x) for x in result] -def _op_identifier(schema) -> tuple[str, str, int]: - return schema.name, schema.domain, schema.since_version +def _op_identifier( + op_schema_or_signature: onnx.defs.OpSchema | _schemas.OpSignature, +) -> tuple[str, str, int]: + return ( + op_schema_or_signature.name, + op_schema_or_signature.domain, + op_schema_or_signature.since_version, + ) class ORTEvaluator(BaseEvaluator): diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 08dc6d32fc..007b50a036 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -382,7 +382,7 @@ def _to_model_proto( functions = sub_functions.values() # Determine opset imports - opset_imports = self.function_ir.graph.opset_imports + opset_imports = self.function_ir.graph.opset_imports.copy() for func in functions: domain = func.opset.domain From 2e96a0b440c4a0cb0b0eee28b0a30a2fb1b3e968 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 10:34:18 -0800 Subject: [PATCH 11/31] update opset import Signed-off-by: Justin Chu --- onnxscript/_internal/values.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 007b50a036..62ad0e8b7c 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -382,7 +382,8 @@ def _to_model_proto( functions = sub_functions.values() # Determine opset imports - opset_imports = self.function_ir.graph.opset_imports.copy() + main_graph = self.function_ir.graph.clone() + opset_imports = main_graph.opset_imports for func in functions: domain = func.opset.domain @@ -406,7 +407,7 @@ def _to_model_proto( ir_version = select_ir_version(opset_imports[""]) # Create the model - model = ir.Model(self.function_ir.graph, ir_version=ir_version) + model = ir.Model(main_graph, ir_version=ir_version) for func in functions: model.functions[func.function_ir.identifier()] = func.function_ir From e1e17eea52a697d36e3d524c92da0534f564dcd8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 10:39:44 -0800 Subject: [PATCH 12/31] Support user functions Signed-off-by: Justin Chu --- onnxscript/_internal/values.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 62ad0e8b7c..984b7ee26d 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -350,6 +350,7 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, + functions: Optional[Sequence[ir.Function | onnx.FunctionProto | OnnxFunction]] = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -377,15 +378,32 @@ def _to_model_proto( Returns: An instance of :class:`onnx.ModelProto`. """ - # Identify functions to include in the model - sub_functions = self.function_ir.get_called_functions() - functions = sub_functions.values() - - # Determine opset imports + if functions is None: + # Identify functions to include in the model + sub_functions = self.function_ir.get_called_functions() + ir_functions = sub_functions.values() + else: + ir_functions = [] + for func in functions: + if isinstance(func, ir.Function): + ir_functions.append(func) + elif isinstance(func, onnx.FunctionProto): + ir_functions.append(ir.serde.deserialize_function(func)) + elif isinstance(func, OnnxFunction): + ir_functions.append(func.function_ir) + else: + raise TypeError( + f"functions must be a sequence of " + f"ir.Function, onnx.FunctionProto, or OnnxFunction, " + f"not {type(func)!r}." + ) + + # Duplicate the graph to create the model main_graph = self.function_ir.graph.clone() + # Determine opset imports opset_imports = main_graph.opset_imports - for func in functions: + for func in ir_functions: domain = func.opset.domain if domain is not None and domain not in opset_imports: opset_imports[domain] = func.opset.version @@ -408,7 +426,7 @@ def _to_model_proto( # Create the model model = ir.Model(main_graph, ir_version=ir_version) - for func in functions: + for func in ir_functions: model.functions[func.function_ir.identifier()] = func.function_ir model_proto = ir.to_proto(model) From 367c942e21f8c94350884420a5ec21733bcc3bae Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 10:52:01 -0800 Subject: [PATCH 13/31] Fix an error where the function return value is not the same as the expression output Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 1c1c0963ad..16afb12b0c 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -1110,7 +1110,7 @@ def check_num_outputs(n): def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" return_var = self._translate_expr(exp, preferred_name) # TODO(rama) - val = self._lookup(return_var.name, self._source_of(exp), False) + val = self._lookup(return_var.name, self._source_of(exp), raise_exception=False) if isinstance(val, values.SymbolValue) and isinstance(val.value, ir.Value): if val.value.is_graph_input(): # In ONNX, a graph-input cannot be an output of the graph. @@ -1121,13 +1121,9 @@ def ret(exp, i, suffix): # ONNX does not allow duplicate output names. return_var = self._emit_copy(return_var, f"{return_var}_copy") break - if self.returntype is None: - t = None - else: - t = self.returntype[i] - self._current_fn.outputs.append( - make_value(return_var.name, t, self._source_of(stmt)) - ) + if self.returntype is not None: + set_type_info(return_var, self.returntype[i]) + self._current_fn.outputs.append(return_var) return return_var val = stmt.value From 7bd47e7e0ea813cf321cba1dcf5d08cd022a9db0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 11:10:59 -0800 Subject: [PATCH 14/31] Reverse the scope lists for efficiency Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 16afb12b0c..8a5a82c92a 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -273,25 +273,25 @@ def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ - self._outer.insert(0, self._current_fn) + self._outer.append(self._current_fn) self._current_fn = irbuilder.IRFunction(name) - self._locals.insert(0, {}) + self._locals.append({}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def _exit_scope(self) -> irbuilder.IRFunction: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn - self._current_fn = self._outer.pop(0) - self._locals.pop(0) + self._current_fn = self._outer.pop() + self._locals.pop() return graph def _current_scope(self) -> dict[str, LocalSymValue]: - return self._locals[0] + return self._locals[-1] def _bind(self, name: str, val: LocalSymValue) -> None: logger.debug("Converter:_bind:%s", name) - self._locals[0][name] = val + self._locals[-1][name] = val def _lookup( self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True @@ -302,7 +302,7 @@ def _lookup( cases include: constant values or functions (mapped to Graph attributes), etc. """ - for scope in self._locals: + for scope in reversed(self._locals): if name in scope: return scope[name] if name in self.globals: @@ -1392,7 +1392,7 @@ def _translate_block( self._current_fn.outputs.append(output) else: python_var_value = None - for scope in self._locals: # TODO: skip _current_scope + for scope in reversed(self._locals): # TODO: skip _current_scope if python_var in scope: python_var_value = scope[python_var] break From 59a9eaac4a5c8fa24043e6ce183de5ecd6d47ef8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 12:45:32 -0800 Subject: [PATCH 15/31] Bump ir version Signed-off-by: Justin Chu --- noxfile.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 71d28f8edd..2bac0c4373 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,7 +41,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.13" +ONNX_IR = "onnx_ir==0.1.15" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/pyproject.toml b/pyproject.toml index 37cdf3d4ea..1f9025fb9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.13,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.15,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.17", "packaging", "typing_extensions>=4.10", From dd2396dac5a116901a78eb295840e049164c93f5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 14:46:14 -0800 Subject: [PATCH 16/31] wip Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 31de69b799..f8b26ffb8b 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -1109,7 +1109,7 @@ def check_num_outputs(n): def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" - return_var = self._translate_expr(exp, preferred_name) # TODO(rama) + return_var = self._translate_expr(exp, preferred_name) val = self._lookup(return_var.name, self._source_of(exp), raise_exception=False) if isinstance(val, values.SymbolValue) and isinstance(val.value, ir.Value): if val.value.is_graph_input(): From 9d933ae0deb1a31e65886ea34bbefc64aa6897d1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 Jan 2026 18:51:08 -0800 Subject: [PATCH 17/31] Fix cond Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 50 +++++++++++++------------------ 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index f8b26ffb8b..59d53004ff 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -1203,7 +1203,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." o_loop_bound = self._translate_expr(iter.args[0], "loop_bound") - onnx_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama) + onnx_cond_var = make_value( + self.generate_unique_name("cond_in"), + onnx_types.BOOL, + self._source_of(loop_stmt), + ) i_cond_var = onnx_cond_var cond_while = None o_loop_condition = None # No condition for a for loop. @@ -1217,8 +1221,16 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: ) python_loop_var_name = "infinite_loop" o_loop_bound = None - i_cond_var = ir.Value(name=test.id) # TODO(Rama) - cond_while = ir.Value(name=test.id) # TODO(Rama) + i_cond_var = make_value( + self.generate_unique_name(test.id), + onnx_types.BOOL, + self._source_of(loop_stmt), + ) + cond_while = make_value( + self.generate_unique_name(test.id), + onnx_types.BOOL, + self._source_of(loop_stmt), + ) onnx_cond_var = None o_loop_condition = self._translate_name_expr(test) # we need to go through all the instructions to see @@ -1251,20 +1263,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)), ) - self._current_fn.append_parameter( - make_value( - i_cond_var.name, - onnx_types.BOOL, - self._source_of(loop_stmt), - ) - ) + self._current_fn.append_parameter(i_cond_var) for pv in loop_state_vars: onnx_var_name = self.generate_unique_name(pv) - # TODO: retrieve the annotation for variable pv is any is specified. - # typeinfo = self._eval_constant_expr(pv.annotation) - typeinfo = None - parameter = make_value(onnx_var_name, typeinfo, self._source_of(loop_stmt)) + parameter = make_value(onnx_var_name, None, self._source_of(loop_stmt)) self._current_fn.append_parameter(parameter) self._bind( pv, @@ -1303,8 +1306,6 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: continue self._translate_stmt(s) - onnx_cond_out_name = self.generate_unique_name("cond_out") - if cond_while is not None: # Loop while current_scope = self._current_scope() @@ -1316,20 +1317,15 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: ) onnx_cond_var = current_scope[cond_while.name].value - self.emit( + onnx_cond_out_name = self.generate_unique_name("cond_out") + cond_out = self.emit( [onnx_cond_out_name], values.Op(self.default_opset, operator_name), [condition_name or onnx_cond_var], [], ) + self._current_fn.outputs.append(cond_out) - self._current_fn.outputs.append( - make_value( - onnx_cond_out_name, - onnx_types.BOOL, - self._source_of(loop_stmt), - ) - ) for pv in loop_state_vars: onnx_var = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) if onnx_var.name not in self._current_fn.assigned_names: @@ -1339,11 +1335,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # In this case, we create a copy of y, treating the statement as # shorthand for "x = op.Identity(y)". onnx_var = self._emit_copy(onnx_var, pv) - # TODO: retrieve variable type for the annotation if any. - typeinfo = None - self._current_fn.outputs.append( - make_value(onnx_var.name, typeinfo, self._source_of(loop_stmt)) - ) + self._current_fn.outputs.append(onnx_var) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars From bd09ed39ccefdda1e0d921e58b90590c2fc21b89 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 13:20:04 -0800 Subject: [PATCH 18/31] Get opset version Signed-off-by: Justin Chu --- onnxscript/_internal/values.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 984b7ee26d..f4fe58cdcd 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -297,6 +297,8 @@ def __init__( super().__init__(opset, irfun.name) self.function = pyfun self.function_ir = irfun + # Record the function domain's opset version in the function metadata + self.function_ir.meta["opset_version"] = opset.version self.source = source self.kwargs = kwargs self._signature = _schemas.OpSignature.from_function( @@ -381,7 +383,9 @@ def _to_model_proto( if functions is None: # Identify functions to include in the model sub_functions = self.function_ir.get_called_functions() - ir_functions = sub_functions.values() + ir_functions: list[ir.Function] = [ + func.function_ir for func in sub_functions.values() + ] else: ir_functions = [] for func in functions: @@ -404,12 +408,12 @@ def _to_model_proto( opset_imports = main_graph.opset_imports for func in ir_functions: - domain = func.opset.domain + domain = func.domain if domain is not None and domain not in opset_imports: - opset_imports[domain] = func.opset.version + opset_imports[domain] = func.meta.get("opset_version", 1) - if "" not in opset_imports and "" in func.function_ir.opset_imports: - opset_imports[""] = func.function_ir.opset_imports[""] + if "" not in opset_imports and "" in func.opset_imports: + opset_imports[""] = func.opset_imports[""] if "" not in opset_imports: # No operator is using the standard opset. From 4d92d81c4fedf788a7fd193fb79acda73cdf2663 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 13:21:27 -0800 Subject: [PATCH 19/31] update Signed-off-by: Justin Chu --- onnxscript/_internal/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index f4fe58cdcd..ca34c36f0f 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -431,7 +431,7 @@ def _to_model_proto( # Create the model model = ir.Model(main_graph, ir_version=ir_version) for func in ir_functions: - model.functions[func.function_ir.identifier()] = func.function_ir + model.functions[func.identifier()] = func model_proto = ir.to_proto(model) From f2f747aa1dd4c0d69d57d32f6f5fa14156c69c95 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 18:25:04 -0800 Subject: [PATCH 20/31] Handle graph attributes Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 43 ++++++++------------------ onnxscript/_internal/converter.py | 51 +++++++++---------------------- onnxscript/_internal/evaluator.py | 8 +++-- 3 files changed, 34 insertions(+), 68 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 99d0e82f5b..ddccc44ade 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -6,8 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import numpy as np -import onnx -import onnx.helper # noqa: TID251 from onnxscript import ir, tensor from onnxscript.ir import _schemas @@ -20,23 +18,15 @@ # python values into ONNX TensorProto, while the runtime converts python values into # ONNXScript runtime's value-representation (based on Tensor). - -# Utilities to convert a python value to TensorProto (for use by the script converter) - - -def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue): - return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name)) - - _REPEATED_ATTRIBUTE_TYPES = frozenset( { - onnx.AttributeProto.FLOATS, - onnx.AttributeProto.INTS, - onnx.AttributeProto.STRINGS, - onnx.AttributeProto.TENSORS, - onnx.AttributeProto.GRAPHS, - onnx.AttributeProto.SPARSE_TENSORS, - onnx.AttributeProto.TYPE_PROTOS, + ir.AttributeType.FLOATS, + ir.AttributeType.INTS, + ir.AttributeType.STRINGS, + ir.AttributeType.TENSORS, + ir.AttributeType.GRAPHS, + ir.AttributeType.SPARSE_TENSORS, + ir.AttributeType.TYPE_PROTOS, } ) @@ -45,33 +35,26 @@ def pyvalue_to_onnx_attribute( key: str, value: Any, name_generator: Callable[[], str], - attr_type: onnx.AttributeProto.AttributeType | None = None, -) -> onnx.AttributeProto: + attr_type: ir.AttributeType | None = None, +) -> ir.Attr: """Helper function to create an ONNX AttributeProto. - This is a refinement of onnx.helper.make_attribute that works with ONNX Script - conventions for allowed types for attribute-values. In particular, it allows - * Empty lists as attribute values, provided the attribute type is specified + * Empty lists can be attribute values, provided the attribute type is specified and is a list type. * Scalar-values like 1.0 as well as lists like [1, -1] to be specified when the attribute type is TensorProto by automatically converting the value into a 0-D or 1-D tensor respectively. """ + # TODO(justinchuby): Remove this function and use onnx-ir directly. if isinstance(value, list) and not value: # Empty list value: if attr_type is None: raise ValueError("Attribute type must be specified for empty list value.") if attr_type not in _REPEATED_ATTRIBUTE_TYPES: raise ValueError("Empty list value is only allowed for repeated attribute types.") - return onnx.AttributeProto(name=key, type=attr_type) - elif attr_type == onnx.AttributeProto.TENSOR and not isinstance(value, onnx.TensorProto): - return onnx.AttributeProto( - name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value) - ) + return ir.Attr(name=key, type=attr_type, value=[]) else: - # When the value is a subgraph, ONNX IR will complain that some values are - # not found from the scope. - return onnx.helper.make_attribute(key, value) # noqa: TID251 + return ir.convenience.convert_attribute(key, value, attr_type=attr_type) # Utilities to convert python values into onnxscript tensors. diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index cea5b58505..bc70658bfd 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -320,21 +320,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - def _make_onnx_attr( - self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> ir.Attr: - if isinstance(attrval, ir.Graph): - return ir.Attr(attrname, ir.AttributeType.GRAPH, attrval) - - def tensor_name_generator() -> str: - """Return name to be used for tensor, if we need to create one.""" - return self.generate_unique_name(f"attr_{attrname}") - - proto = autocast.pyvalue_to_onnx_attribute( - attrname, attrval, tensor_name_generator, attrtype - ) - return ir.from_proto(proto) - def _to_onnx_attr_ref( self, val: values.AttrRef, info: sourceinfo.SourceInfo | None ) -> ir.Attr: @@ -371,7 +356,7 @@ def _to_onnx_var( # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self.generate_unique_name(result_name + "_as_bool") - cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) + cast_attr = ir.AttrInt64("to", onnx_types.BOOL.dtype) self._castable.add(result_as_bool) return self.emit1( [result_as_bool], @@ -448,11 +433,8 @@ def _emit_const( else: suggested_name = "const" ovar = self.generate_unique_name(suggested_name) - try: - tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue) - except ValueError as e: - fail(info.msg(str(e))) - attr = self._make_onnx_attr("value", tensor) + + attr = ir.AttrTensor("value", ir.tensor(pyvalue, name=ovar)) self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) @@ -581,13 +563,9 @@ def _translate_attr( if attr_meta and attr_meta.required: self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = int(attr_meta.type) if attr_meta else None - attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type) - if attr_meta and (attr.type != attr_meta.type): - self.fail( - expr, - f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'", - ) + attr = ir.Attr( + attr_name, attr_meta.type if attr_meta else ir.AttributeType.UNDEFINED, val + ) return attr def _translate_docstring(self, node: ast.Expr) -> None: @@ -805,7 +783,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value steps.append(inputs[2]) if len(starts) > 1: - axis_0_attr = self._make_onnx_attr("axis", 0) + axis_0_attr = ir.AttrInt64("axis", 0) start_name = self.generate_unique_name(f"{var_name}_start") start_value = self.emit([start_name], "Concat", starts, [axis_0_attr]) @@ -855,7 +833,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) - axis_attr = self._make_onnx_attr("axis", axis) + axis_attr = ir.AttrInt64("axis", axis) # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather @@ -901,19 +879,20 @@ def _translate_binary_op_expr(self, node: ast.BinOp): if op not in primop_map: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) - attr = [] if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): - attr = [self._make_onnx_attr("fmod", 1)] + attrs = [ir.AttrInt64("fmod", 1)] + else: + attrs = [] op = values.Op(self.default_opset, primop_map[op]) left, right = self._cast_like_binary_expression( op, self._translate_expr(node.left), self._translate_expr(node.right) ) - return op, [left, right], attr + return op, [left, right], attrs def _translate_unary_op_expr(self, node): op = type(node.op) @@ -1156,9 +1135,9 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno thenGraph = self._translate_block(stmt.body, f"thenGraph_{lineno}", live_defs) - thenAttr = self._make_onnx_attr("then_branch", thenGraph) + thenAttr = ir.AttrGraph("then_branch", thenGraph) elseGraph = self._translate_block(stmt.orelse, f"elseGraph_{lineno}", live_defs) - elseAttr = self._make_onnx_attr("else_branch", elseGraph) + elseAttr = ir.AttrGraph("else_branch", elseGraph) def rename(x): return self.generate_unique_name(x) @@ -1335,7 +1314,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - attrs = [self._make_onnx_attr("body", body.graph)] + attrs = [ir.AttrGraph("body", body.graph)] info = self._source_of(loop_stmt) def rename(x): diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index d6075d29be..9252996f99 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -21,6 +21,7 @@ import onnx import onnx.defs import onnx.reference +import onnx_ir as ir from typing_extensions import TypeAlias from onnxscript import onnx_opset, tensor @@ -418,12 +419,15 @@ def _prepare_model_and_inputs_for_eager( implicit_args = {k: _onnxscript_to_numpy_value(v) for k, v in implicit_args.items()} # Utility to convert kwarg to ONNX AttributeProto: + # TODO(justinchuby): Clean up this function to use onnx-ir def make_attr(key: str, value: Any) -> onnx.AttributeProto: def make_tensor_name() -> str: return f"attr_{key}" - return autocast.pyvalue_to_onnx_attribute( - key, value, make_tensor_name, int(schema.attributes[key].type) + return ir.to_proto( + autocast.pyvalue_to_onnx_attribute( + key, value, make_tensor_name, int(schema.attributes[key].type) + ) ) # Construct ONNX model with a single op call: From bcaec1b3425da9b177ce0c7bca97da2106c0bd0d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 18:37:25 -0800 Subject: [PATCH 21/31] Clean ta Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 3 + onnxscript/_internal/converter.py | 11 +- onnxscript/_internal/type_annotation.py | 79 --------- onnxscript/_internal/type_annotation_test.py | 159 +------------------ onnxscript/ir/_schemas.py | 4 +- 5 files changed, 12 insertions(+), 244 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index ddccc44ade..8b14fa2932 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import numpy as np +import onnx from onnxscript import ir, tensor from onnxscript.ir import _schemas @@ -53,6 +54,8 @@ def pyvalue_to_onnx_attribute( if attr_type not in _REPEATED_ATTRIBUTE_TYPES: raise ValueError("Empty list value is only allowed for repeated attribute types.") return ir.Attr(name=key, type=attr_type, value=[]) + elif attr_type == ir.AttributeType.TENSOR and not isinstance(value, onnx.TensorProto): + return ir.AttrTensor(name=key, value=ir.tensor(value, name=name_generator())) else: return ir.convenience.convert_attribute(key, value, attr_type=attr_type) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index bc70658bfd..84891166a5 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -563,9 +563,10 @@ def _translate_attr( if attr_meta and attr_meta.required: self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr = ir.Attr( - attr_name, attr_meta.type if attr_meta else ir.AttributeType.UNDEFINED, val - ) + attr_type = attr_meta.type if attr_meta else ir.AttributeType.UNDEFINED + if attr_type == ir.AttributeType.TENSOR: + val = ir.tensor(val) + attr = ir.Attr(attr_name, attr_type, val) return attr def _translate_docstring(self, node: ast.Expr) -> None: @@ -1411,8 +1412,8 @@ def _translate_function_signature_common( # The code can only be exported as a function. typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): - attribute_type = ta.pytype_to_attrtype(typeinfo) - attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None) + attribute_type = _schemas.get_attr_type(typeinfo) + attr = ir.Attr(x.arg, attribute_type, default_value, None) self._current_fn.append_parameter(attr) as_bool = ta.base_type_is_bool(typeinfo) self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x))) diff --git a/onnxscript/_internal/type_annotation.py b/onnxscript/_internal/type_annotation.py index fb7b8a370d..78e6d78343 100644 --- a/onnxscript/_internal/type_annotation.py +++ b/onnxscript/_internal/type_annotation.py @@ -215,82 +215,3 @@ def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: if typing.get_origin(typeinfo) is tuple: return typing.get_args(typeinfo) return (typeinfo,) - - -def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]: - """Returns a list of type-strings corresponding to a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - A list of all supported input types for the given type annotation. - Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS. - """ - if pytype is None: - return list(ALL_TENSOR_TYPE_STRINGS) - if pytype is onnx_types.TensorType: - return list(ALL_TENSOR_TYPE_STRINGS) - if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, typing.TypeVar): - constraints = pytype.__constraints__ - if constraints: - return pytype_to_type_strings(Union.__getitem__(constraints)) # pylint: disable=unnecessary-dunder-call - bound = pytype.__bound__ - if bound is None: - return list(ALL_TENSOR_TYPE_STRINGS) - return pytype_to_type_strings(bound) - if typing.get_origin(pytype) is Union: - options = [] - subtypes = typing.get_args(pytype) - # A None type in a Union is equivalent to an optional type - optional = is_optional(pytype) - for subtype in subtypes: - if subtype is type(None): - # Skip None type because we are handling it with is_optional - continue - if optional: - options += [ - *pytype_to_type_strings(subtype), - *[f"optional({s})" for s in pytype_to_type_strings(subtype)], - ] - else: - options += pytype_to_type_strings(subtype) - # Remove duplicates - return sorted(set(options)) - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])] - - raise ValueError(f"Unsupported type: {pytype}") - - -def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: - """Returns the name of the type constraint for a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - The name of the type constraint if it is a TypeVar. - - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. - - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. - - Returns None if the type annotation does not have a type constraint. - """ - if isinstance(pytype, typing.TypeVar): - return pytype.__name__ - if is_optional(pytype): - subtypes = typing.get_args(pytype) - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = get_type_constraint_name(subtype) - return f"Optional_{type_param_name}" if type_param_name else None - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - type_param_name = get_type_constraint_name(subtypes[0]) - return f"Sequence_{type_param_name}" if type_param_name else None - return None diff --git a/onnxscript/_internal/type_annotation_test.py b/onnxscript/_internal/type_annotation_test.py index 259157e66d..622a04339a 100644 --- a/onnxscript/_internal/type_annotation_test.py +++ b/onnxscript/_internal/type_annotation_test.py @@ -2,14 +2,9 @@ # Licensed under the MIT License. import unittest -from typing import Any, List, Optional, Sequence, TypeVar, Union -import parameterized - -import onnxscript import onnxscript.testing -from onnxscript import FLOAT, INT64, script -from onnxscript._internal import type_annotation +from onnxscript import FLOAT, script from onnxscript.onnx_opset import opset15 as op from tests.common import testutils @@ -90,157 +85,5 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: ) -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) - - -class TypeConversionFunctionsTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ( - "tensor_type_all", - onnxscript.onnx_types.TensorType, - list(type_annotation.ALL_TENSOR_TYPE_STRINGS), - ), - ("none", None, list(type_annotation.ALL_TENSOR_TYPE_STRINGS)), - ("tensor_type", INT64, ["tensor(int64)"]), - ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), - ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), - ("tensor_type_shape", INT64[10], ["tensor(int64)"]), - ( - "type_var_constraints", - _TestTypeVarConstraints, - ["tensor(float)", "tensor(int64)"], - ), - ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), - ("type_bound_two", _TestTypeVarTwoBound, ["tensor(float)", "tensor(int64)"]), - ( - "optional_tensor_type_all", - Optional[onnxscript.onnx_types.TensorType], - [ - *[ - f"optional({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - *type_annotation.ALL_TENSOR_TYPE_STRINGS, - ], - ), - ( - "optional_tensor_type", - Optional[INT64], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_union", - Optional[Union[INT64, FLOAT]], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_tensor_type_variadic_shape", - Optional[INT64[...]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_shape", - Optional[INT64[10]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_var_constraints", - Optional[_TestTypeVarConstraints], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_type_bound_one", - Optional[_TestTypeVarOneBound], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_bound_two", - Optional[_TestTypeVarTwoBound], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "sequence_type_all", - Sequence[onnxscript.onnx_types.TensorType], - [ - f"seq({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - ), - ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]), - ( - "union_sequence_type", - Union[Sequence[INT64], Sequence[FLOAT]], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_variadic_shape", - Sequence[INT64[...]], - ["seq(tensor(int64))"], - ), - ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]), - ( - "sequence_type_var_constraints", - Sequence[_TestTypeVarConstraints], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_bound_one", - Sequence[_TestTypeVarOneBound], - ["seq(tensor(int64))"], - ), - ( - "sequence_type_bound_two", - Sequence[_TestTypeVarTwoBound], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ] - ) - def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]): - self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected) - - @parameterized.parameterized.expand( - [ - ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), - ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), - ( - "optional_type_var", - Optional[_TestTypeVarOneBound], - "Optional__TestTypeVarOneBound", - ), - ( - "sequence_type_var", - Sequence[_TestTypeVarOneBound], - "Sequence__TestTypeVarOneBound", - ), - ("normal_type", INT64, None), - ("union_type", Union[INT64, FLOAT], None), - ("optional_type", Optional[INT64], None), - ("sequence_type", Sequence[INT64], None), - ("optional_sequence_type", Optional[Sequence[INT64]], None), - ("optional_union_type", Optional[Union[INT64, FLOAT]], None), - ] - ) - def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): - self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) - - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 1f14634834..806980f7dc 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -214,7 +214,7 @@ def _is_optional(type_: type) -> bool: return False -def _get_attr_type(type_: type) -> ir.AttributeType: +def get_attr_type(type_: type) -> ir.AttributeType: """Obtain the type of the attribute from a Python class.""" try: if type_ in _PY_TYPE_TO_ATTR_TYPE: @@ -467,7 +467,7 @@ def from_function( ) else: type_ = type_hints[param.name] - if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED: # Construct the default attribute if param.default is not inspect.Parameter.empty: # TODO: Use ir_convenience instead to handle int as float From 078db3466d95e14f0d8d0ca9f8156b6b999a83c6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 18:46:28 -0800 Subject: [PATCH 22/31] Fix graph attributes Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 4 +-- onnxscript/_internal/converter_test.py | 2 +- onnxscript/_internal/type_annotation.py | 40 ++----------------------- 3 files changed, 6 insertions(+), 40 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 84891166a5..947741f8b9 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -540,8 +540,8 @@ def _translate_attr( f"Outer scope variable '{pyvar}' referenced by function " f"'{expr.id!r}' modified.", ) - # Create GraphProto attribute - val = irfunction.to_graph_proto() + + val = irfunction.graph if isinstance(val, ir.Value): self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.") else: diff --git a/onnxscript/_internal/converter_test.py b/onnxscript/_internal/converter_test.py index e825697d8e..c1fe276ef5 100644 --- a/onnxscript/_internal/converter_test.py +++ b/onnxscript/_internal/converter_test.py @@ -291,7 +291,7 @@ def test_renaming(self): self.validate_save(renaming, shape_inference=False) @pytest.mark.xfail( - strict=True, + strict=False, reason="optional output is not yet implemented", ) def test_opt_output(self): diff --git a/onnxscript/_internal/type_annotation.py b/onnxscript/_internal/type_annotation.py index 78e6d78343..a5509f92b6 100644 --- a/onnxscript/_internal/type_annotation.py +++ b/onnxscript/_internal/type_annotation.py @@ -22,22 +22,8 @@ # - Above types with annotation attached TypeAnnotationValue = typing.Any -# Map from python type to corresponding ONNX AttributeProto type -_PYTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOAT, - int: onnx.AttributeProto.INT, - str: onnx.AttributeProto.STRING, - bool: onnx.AttributeProto.INT, # experimental -} -# Map from python type to corresponding ONNX AttributeProto type, -# for repeated (i.e., list of) values -_LISTTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOATS, - int: onnx.AttributeProto.INTS, - str: onnx.AttributeProto.STRINGS, - bool: onnx.AttributeProto.INTS, # experimental -} +_PRIMITIVE_ATTR_TYPE = frozenset((float, int, str, bool)) _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) @@ -92,33 +78,13 @@ def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: - return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP - - -def pytype_to_attrtype( - pytype: TypeAnnotationValue, -) -> Optional[onnx.AttributeProto.AttributeType]: - pytype = _remove_annotation(pytype) - if pytype in _PYTYPE_TO_ATTRTYPE_MAP: - return _PYTYPE_TO_ATTRTYPE_MAP[pytype] - type_constructor = typing.get_origin(pytype) - # Remove Optional wrapper if present, which is represented as an Union[..., type(None)] - if type_constructor is typing.Union: - # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)] - args = [x for x in typing.get_args(pytype) if x is not type(None)] - if len(args) == 1: - return pytype_to_attrtype(args[0]) - if type_constructor in _LIST_CONSTRUCTORS: - elt_type = typing.get_args(pytype)[0] - if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: - return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None + return typeinfo in _PRIMITIVE_ATTR_TYPE def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: """Returns True if base type of pytype is bool, False otherwise.""" pytype = _remove_annotation(pytype) - if pytype in _PYTYPE_TO_ATTRTYPE_MAP: + if pytype in _PRIMITIVE_ATTR_TYPE: return pytype is bool type_constructor = typing.get_origin(pytype) if type_constructor in _LIST_CONSTRUCTORS: From d3b7fd8f3e30bca107a5976fb24a4cccc2874b63 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 18:50:38 -0800 Subject: [PATCH 23/31] Fix Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 947741f8b9..18b3b56721 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -880,14 +880,13 @@ def _translate_binary_op_expr(self, node: ast.BinOp): if op not in primop_map: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) + attrs = [] if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): attrs = [ir.AttrInt64("fmod", 1)] - else: - attrs = [] op = values.Op(self.default_opset, primop_map[op]) left, right = self._cast_like_binary_expression( From 1eefa02aba97aeb04508f99154e570e04a88d0eb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 18:58:34 -0800 Subject: [PATCH 24/31] Fix attr conversion Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 18b3b56721..7c279c60e0 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -563,10 +563,10 @@ def _translate_attr( if attr_meta and attr_meta.required: self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = attr_meta.type if attr_meta else ir.AttributeType.UNDEFINED + attr_type = attr_meta.type if attr_meta else None if attr_type == ir.AttributeType.TENSOR: val = ir.tensor(val) - attr = ir.Attr(attr_name, attr_type, val) + attr = ir.convenience.convert_attribute(attr_name, val, attr_type) return attr def _translate_docstring(self, node: ast.Expr) -> None: From e7b353de9131dcec48b4ef1bdc52c0df3edffc10 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 22:49:37 -0800 Subject: [PATCH 25/31] update Signed-off-by: Justin Chu --- onnxscript/ir/_schemas.py | 10 +++++ tests/function_libs/torch_lib/ops_test.py | 15 +------ .../torch_lib/ops_test_common.py | 45 +++++++++---------- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 806980f7dc..5a37fef34a 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -374,6 +374,16 @@ def __str__(self) -> str: ) return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + @property + def inputs(self) -> Sequence[Parameter]: + """Returns the input parameters.""" + return [param for param in self.params if isinstance(param, Parameter)] + + @property + def attributes(self) -> Sequence[AttributeParameter]: + """Returns the attribute parameters.""" + return [param for param in self.params if isinstance(param, AttributeParameter)] + @classmethod def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: """Produce an OpSignature from an ONNX OpSchema.""" diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index a45050fb22..beb74b5462 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -109,17 +109,6 @@ def test_script_function_passes_checker( function_proto = torchlib_op_info.op.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] - @parameterized.parameterized.expand( - [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] - ) - def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): - func = torchlib_op_info.op - if not hasattr(func, "op_schema"): - raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") - schema = func.op_schema - self.assertIsNotNone(schema) - self.assertEqual(schema.name, func.name) - def run_test_output_match( test_suite: unittest.TestCase, @@ -157,12 +146,12 @@ def run_test_output_match( onnx_function = torchlib_op_info.op input_wrangler = torchlib_op_info.input_wrangler if ( - not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema) + not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_signature) and dtype not in COMPLEX_TYPES ): test_suite.skipTest( f"dtype '{dtype}' is not supported by the op '{op.name}'. " - f"Type constraints: {onnx_function.op_schema.type_constraints}" + f"Type constraints: {onnx_function.op_signature.params}" ) # Obtain the tolerance for the op diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 99594ee17e..05d94cf8a7 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -36,6 +36,7 @@ import onnxscript import onnxscript.evaluator from onnxscript import ir +from onnxscript.ir import _schemas from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") @@ -394,23 +395,23 @@ def _format_model_and_input_information(onnx_model, inputs): return f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}" -TORCH_DTYPE_TO_ONNX_STRING = { - torch.bool: "tensor(bool)", - torch.uint8: "tensor(uint8)", - torch.int8: "tensor(int8)", - torch.int16: "tensor(int16)", - torch.int32: "tensor(int32)", - torch.int64: "tensor(int64)", - torch.float16: "tensor(float16)", - torch.float32: "tensor(float)", - torch.float64: "tensor(double)", - torch.complex64: "tensor(complex64)", - torch.complex128: "tensor(complex128)", - torch.bfloat16: "tensor(bfloat16)", +_TORCH_DTYPE_TO_ONNX_TYPE = { + torch.bool: ir.DataType.BOOL, + torch.uint8: ir.DataType.UINT8, + torch.int8: ir.DataType.INT8, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.complex64: ir.DataType.COMPLEX64, + torch.complex128: ir.DataType.COMPLEX128, + torch.bfloat16: ir.DataType.BFLOAT16, } -def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: +def dtype_op_schema_compatible(dtype: torch.dtype, schema: _schemas.OpSignature) -> bool: """Checks if the dtype is compatible with the schema. When a dtype is "compatible" with the schema, it means we can use the dtype @@ -418,12 +419,12 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - Args: dtype: The torch dtype used to create sample inputs by OpInfo. - schema: The ONNX schema of the function. + schema: The OpSignature of the function. Returns: True if the dtype is compatible with the schema. """ - if not schema.inputs: + if not schema.params: # If there are no inputs, we can't check compatibility. Assume it is compatible. # e.g. aten_randn has only attributes. return True @@ -457,16 +458,12 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - # 'tensor(bfloat16)']. # Since torch.float32 (tensor(float)) is in the allowed types, we return True. - first_input_type_name = schema.inputs[0].type_str - # Find the type constraint for the first input by matching the parameter name - first_input_type_constraint = next( - (x for x in schema.type_constraints if first_input_type_name in x.type_param_str), - None, - ) + first_input_type_constraint = schema.inputs[0].type_constraint assert first_input_type_constraint is not None - allowed_type_strs = first_input_type_constraint.allowed_type_strs + allowed_types = first_input_type_constraint.allowed_types # Here we consider seq(tensor(float)) compatible with tensor(float) as well - return any(TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs) + allowed_dtypes = {type_.dtype for type_ in allowed_types} + return _TORCH_DTYPE_TO_ONNX_TYPE[dtype] in allowed_dtypes def graph_executor( From eef0934a6c7df40f1b477b773af96153255472be Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 22:57:55 -0800 Subject: [PATCH 26/31] update Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 8b14fa2932..f349286144 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -132,9 +132,7 @@ def cast_inputs( return tuple(cast(x, None) for x in args) # Filter to get only input parameters (not AttributeParameters) - expected_inputs = [ - param for param in op_signature.params if isinstance(param, _schemas.Parameter) - ] + expected_inputs = op_signature.inputs # We make two passes. In the first pass, we identify known type-bindings for # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}. # In the second pass, we use these bindings to cast scalar-values to From 4b58dac2c6510592c72fec179109fde6a13ec906 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 22:58:47 -0800 Subject: [PATCH 27/31] Update onnxscript/_internal/converter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/_internal/converter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 7c279c60e0..b72ae5c8f6 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -434,7 +434,14 @@ def _emit_const( suggested_name = "const" ovar = self.generate_unique_name(suggested_name) - attr = ir.AttrTensor("value", ir.tensor(pyvalue, name=ovar)) + try: + tensor = ir.tensor(pyvalue, name=ovar) + except Exception as exc: # noqa: BLE001 + self.fail( + info, + f"Failed to convert constant value {pyvalue!r} to ONNX tensor: {exc}", + ) + attr = ir.AttrTensor("value", tensor) self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) From 838d3c30550dc3921e4ab58f5c1be5af7eac257d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 22:59:01 -0800 Subject: [PATCH 28/31] lint Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index b72ae5c8f6..df5fc8dfab 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -436,7 +436,7 @@ def _emit_const( try: tensor = ir.tensor(pyvalue, name=ovar) - except Exception as exc: # noqa: BLE001 + except Exception as exc: self.fail( info, f"Failed to convert constant value {pyvalue!r} to ONNX tensor: {exc}", From 8ef9c0b2c613e145ca42d8888bf05c67bb29ac3c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 23:00:15 -0800 Subject: [PATCH 29/31] update Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 05d94cf8a7..f2deb7cd07 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -424,11 +424,12 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: _schemas.OpSignature) Returns: True if the dtype is compatible with the schema. """ - if not schema.params: + inputs = schema.inputs + if not inputs: # If there are no inputs, we can't check compatibility. Assume it is compatible. # e.g. aten_randn has only attributes. return True - if schema.inputs[0].name not in {"self", "input"}: + if inputs[0].name not in {"self", "input"}: # If the name of the first input is not "self" or "input", # it is usually an input that is not of the same type as the output. # We assume support in this case. @@ -458,7 +459,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: _schemas.OpSignature) # 'tensor(bfloat16)']. # Since torch.float32 (tensor(float)) is in the allowed types, we return True. - first_input_type_constraint = schema.inputs[0].type_constraint + first_input_type_constraint = inputs[0].type_constraint assert first_input_type_constraint is not None allowed_types = first_input_type_constraint.allowed_types # Here we consider seq(tensor(float)) compatible with tensor(float) as well From e4fb6291f23ecd7596c644212a846b43617fc247 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 22 Jan 2026 23:08:56 -0800 Subject: [PATCH 30/31] fail text Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 4 ++-- onnxscript/_internal/values.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index df5fc8dfab..e3d6f44854 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -436,9 +436,9 @@ def _emit_const( try: tensor = ir.tensor(pyvalue, name=ovar) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-exception-caught self.fail( - info, + info.ast_node, f"Failed to convert constant value {pyvalue!r} to ONNX tensor: {exc}", ) attr = ir.AttrTensor("value", tensor) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index ca34c36f0f..4f54184363 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -127,7 +127,7 @@ def __getattr__(self, attr: str) -> Op: try: schema = onnx.defs.get_schema(attr, self.version, self.domain) return Op(self, attr, schema) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-exception-caught raise AttributeError(f"Attribute {attr} not found.") from exc def add_function_def(self, fun): From cb1fe9464f26b1beefff519f6a8e22a849f3b0b1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 26 Jan 2026 12:21:23 -0800 Subject: [PATCH 31/31] Remove clone Signed-off-by: Justin Chu --- docs/examples/06_plot_model_local_funs.py | 13 -------- onnxscript/_internal/values.py | 38 ++++++++--------------- 2 files changed, 13 insertions(+), 38 deletions(-) diff --git a/docs/examples/06_plot_model_local_funs.py b/docs/examples/06_plot_model_local_funs.py index fdb0e434bb..7b06b54754 100644 --- a/docs/examples/06_plot_model_local_funs.py +++ b/docs/examples/06_plot_model_local_funs.py @@ -47,16 +47,3 @@ def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 model = l2norm.to_model_proto() print(onnx.printer.to_text(model)) - -# %% -# Let's now explicitly specify which functions to include. -# First, generate a model with no model-local functions: - -model = l2norm.to_model_proto(functions=[]) -print(onnx.printer.to_text(model)) - -# %% -# Now, generate a model with one model-local function: - -model = l2norm.to_model_proto(functions=[sum]) -print(onnx.printer.to_text(model)) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 4f54184363..5c559a5504 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -352,7 +352,6 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions: Optional[Sequence[ir.Function | onnx.FunctionProto | OnnxFunction]] = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -363,8 +362,6 @@ def _to_model_proto( """Converts this instance into a `onnx.ModelProto`. Args: - functions: A list of functions to include in the model. - By default, all functions called at least once are included. io_types: When specified, all the inputs/outputs of the model are set to be of this type. input_types: When specified, all the inputs of the model @@ -380,32 +377,14 @@ def _to_model_proto( Returns: An instance of :class:`onnx.ModelProto`. """ - if functions is None: - # Identify functions to include in the model - sub_functions = self.function_ir.get_called_functions() - ir_functions: list[ir.Function] = [ - func.function_ir for func in sub_functions.values() - ] - else: - ir_functions = [] - for func in functions: - if isinstance(func, ir.Function): - ir_functions.append(func) - elif isinstance(func, onnx.FunctionProto): - ir_functions.append(ir.serde.deserialize_function(func)) - elif isinstance(func, OnnxFunction): - ir_functions.append(func.function_ir) - else: - raise TypeError( - f"functions must be a sequence of " - f"ir.Function, onnx.FunctionProto, or OnnxFunction, " - f"not {type(func)!r}." - ) + # Identify functions to include in the model + sub_functions = self.function_ir.get_called_functions() + ir_functions: list[ir.Function] = [func.function_ir for func in sub_functions.values()] # Duplicate the graph to create the model main_graph = self.function_ir.graph.clone() # Determine opset imports - opset_imports = main_graph.opset_imports + opset_imports = main_graph.opset_imports.copy() for func in ir_functions: domain = func.domain @@ -435,6 +414,15 @@ def _to_model_proto( model_proto = ir.to_proto(model) + # Update opset imports + del model_proto.opset_import[:] + model_proto.opset_import.extend( + [ + onnx.OperatorSetIdProto(domain=domain, version=version) + for domain, version in opset_imports.items() + ] + ) + # Set additional type information if provided graph = model_proto.graph