diff --git a/docs/examples/06_plot_model_local_funs.py b/docs/examples/06_plot_model_local_funs.py index fdb0e434bb..7b06b54754 100644 --- a/docs/examples/06_plot_model_local_funs.py +++ b/docs/examples/06_plot_model_local_funs.py @@ -47,16 +47,3 @@ def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821 model = l2norm.to_model_proto() print(onnx.printer.to_text(model)) - -# %% -# Let's now explicitly specify which functions to include. -# First, generate a model with no model-local functions: - -model = l2norm.to_model_proto(functions=[]) -print(onnx.printer.to_text(model)) - -# %% -# Now, generate a model with one model-local function: - -model = l2norm.to_model_proto(functions=[sum]) -print(onnx.printer.to_text(model)) diff --git a/noxfile.py b/noxfile.py index 71d28f8edd..2bac0c4373 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,7 +41,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.13" +ONNX_IR = "onnx_ir==0.1.15" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index bc3e16f79e..f349286144 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -7,10 +7,9 @@ import numpy as np import onnx -import onnx.helper # noqa: TID251 -from onnx.defs import OpSchema from onnxscript import ir, tensor +from onnxscript.ir import _schemas if TYPE_CHECKING: from onnxscript._internal import converter @@ -20,23 +19,15 @@ # python values into ONNX TensorProto, while the runtime converts python values into # ONNXScript runtime's value-representation (based on Tensor). - -# Utilities to convert a python value to TensorProto (for use by the script converter) - - -def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue): - return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name)) - - _REPEATED_ATTRIBUTE_TYPES = frozenset( { - onnx.AttributeProto.FLOATS, - onnx.AttributeProto.INTS, - onnx.AttributeProto.STRINGS, - onnx.AttributeProto.TENSORS, - onnx.AttributeProto.GRAPHS, - onnx.AttributeProto.SPARSE_TENSORS, - onnx.AttributeProto.TYPE_PROTOS, + ir.AttributeType.FLOATS, + ir.AttributeType.INTS, + ir.AttributeType.STRINGS, + ir.AttributeType.TENSORS, + ir.AttributeType.GRAPHS, + ir.AttributeType.SPARSE_TENSORS, + ir.AttributeType.TYPE_PROTOS, } ) @@ -45,33 +36,28 @@ def pyvalue_to_onnx_attribute( key: str, value: Any, name_generator: Callable[[], str], - attr_type: onnx.AttributeProto.AttributeType | None = None, -) -> onnx.AttributeProto: + attr_type: ir.AttributeType | None = None, +) -> ir.Attr: """Helper function to create an ONNX AttributeProto. - This is a refinement of onnx.helper.make_attribute that works with ONNX Script - conventions for allowed types for attribute-values. In particular, it allows - * Empty lists as attribute values, provided the attribute type is specified + * Empty lists can be attribute values, provided the attribute type is specified and is a list type. * Scalar-values like 1.0 as well as lists like [1, -1] to be specified when the attribute type is TensorProto by automatically converting the value into a 0-D or 1-D tensor respectively. """ + # TODO(justinchuby): Remove this function and use onnx-ir directly. if isinstance(value, list) and not value: # Empty list value: if attr_type is None: raise ValueError("Attribute type must be specified for empty list value.") if attr_type not in _REPEATED_ATTRIBUTE_TYPES: raise ValueError("Empty list value is only allowed for repeated attribute types.") - return onnx.AttributeProto(name=key, type=attr_type) - elif attr_type == onnx.AttributeProto.TENSOR and not isinstance(value, onnx.TensorProto): - return onnx.AttributeProto( - name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value) - ) + return ir.Attr(name=key, type=attr_type, value=[]) + elif attr_type == ir.AttributeType.TENSOR and not isinstance(value, onnx.TensorProto): + return ir.AttrTensor(name=key, value=ir.tensor(value, name=name_generator())) else: - # When the value is a subgraph, ONNX IR will complain that some values are - # not found from the scope. - return onnx.helper.make_attribute(key, value) # noqa: TID251 + return ir.convenience.convert_attribute(key, value, attr_type=attr_type) # Utilities to convert python values into onnxscript tensors. @@ -126,7 +112,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None): def cast_inputs( get_type_info: Callable[[Any], Any], cast: Callable[[Any, Any], Any], - op_schema: OpSchema | None, + op_signature: _schemas.OpSignature | None, args, ) -> tuple[Any, ...]: """Uses schema specification to support a limited form of auto-casting. @@ -140,12 +126,13 @@ def cast_inputs( This is used by the converter in a static-mode, as well as by the eager-mode execution in a dynamic-mode. """ - if op_schema is None: + if op_signature is None: # Either an error or a custom op. # No checks/casts in this case. return tuple(cast(x, None) for x in args) - expected_inputs = op_schema.inputs + # Filter to get only input parameters (not AttributeParameters) + expected_inputs = op_signature.inputs # We make two passes. In the first pass, we identify known type-bindings for # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}. # In the second pass, we use these bindings to cast scalar-values to @@ -156,9 +143,9 @@ def cast_inputs( for i, x in enumerate(args): if i < len(expected_inputs): expected = expected_inputs[i] - elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic: + elif expected_inputs[-1].variadic: expected = expected_inputs[-1] - if not expected.is_homogeneous: + if not expected.homogeneous: args_typevars.append((x, None)) continue else: @@ -166,7 +153,7 @@ def cast_inputs( f"Number of actual parameters {len(args)} " f"exceeds number of formal parameters {len(expected_inputs)}." ) - typevar = expected.type_str + typevar = expected.type_constraint.name if "(" not in typevar: # typevar is an identifier, like "T" typeinfo = get_type_info(x) @@ -177,18 +164,18 @@ def cast_inputs( return tuple(cast_args) -def dynamic_cast_inputs(op_schema: OpSchema, args): +def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args): """Used for autocast during eager-mode execution.""" def get_type_info(x): return x.dtype if isinstance(x, tensor.Tensor) else None - return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args) + return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args) def static_cast_inputs( converter_: converter.Converter, - op_schema: Optional[OpSchema], + op_signature: Optional[_schemas.OpSignature], args: Sequence[Optional[ir.Value]], ) -> tuple[str, ...]: """Used for autocast during script-translation. @@ -212,4 +199,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]: return converter_.emit1([x_cast], "CastLike", [x, y]) return x - return cast_inputs(get_type_info, cast_like, op_schema, args) + return cast_inputs(get_type_info, cast_like, op_signature, args) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index bf447047af..e3d6f44854 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -12,7 +12,6 @@ Union, ) -import onnx import onnx_ir as ir import onnxscript @@ -29,6 +28,7 @@ from onnxscript._internal import ( type_annotation as ta, ) +from onnxscript.ir import _schemas logger = logging.getLogger("onnxscript") @@ -273,25 +273,25 @@ def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ - self._outer.insert(0, self._current_fn) + self._outer.append(self._current_fn) self._current_fn = irbuilder.IRFunction(name) - self._locals.insert(0, {}) + self._locals.append({}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def _exit_scope(self) -> irbuilder.IRFunction: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn - self._current_fn = self._outer.pop(0) - self._locals.pop(0) + self._current_fn = self._outer.pop() + self._locals.pop() return graph def _current_scope(self) -> dict[str, LocalSymValue]: - return self._locals[0] + return self._locals[-1] def _bind(self, name: str, val: LocalSymValue) -> None: logger.debug("Converter:_bind:%s", name) - self._locals[0][name] = val + self._locals[-1][name] = val def _lookup( self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True @@ -302,7 +302,7 @@ def _lookup( cases include: constant values or functions (mapped to Graph attributes), etc. """ - for scope in self._locals: + for scope in reversed(self._locals): if name in scope: return scope[name] if name in self.globals: @@ -320,21 +320,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - def _make_onnx_attr( - self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> ir.Attr: - if isinstance(attrval, ir.Graph): - return ir.Attr(attrname, ir.AttributeType.GRAPH, attrval) - - def tensor_name_generator() -> str: - """Return name to be used for tensor, if we need to create one.""" - return self.generate_unique_name(f"attr_{attrname}") - - proto = autocast.pyvalue_to_onnx_attribute( - attrname, attrval, tensor_name_generator, attrtype - ) - return ir.from_proto(proto) - def _to_onnx_attr_ref( self, val: values.AttrRef, info: sourceinfo.SourceInfo | None ) -> ir.Attr: @@ -371,7 +356,7 @@ def _to_onnx_var( # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self.generate_unique_name(result_name + "_as_bool") - cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) + cast_attr = ir.AttrInt64("to", onnx_types.BOOL.dtype) self._castable.add(result_as_bool) return self.emit1( [result_as_bool], @@ -448,11 +433,15 @@ def _emit_const( else: suggested_name = "const" ovar = self.generate_unique_name(suggested_name) + try: - tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue) - except ValueError as e: - fail(info.msg(str(e))) - attr = self._make_onnx_attr("value", tensor) + tensor = ir.tensor(pyvalue, name=ovar) + except Exception as exc: # pylint: disable=broad-exception-caught + self.fail( + info.ast_node, + f"Failed to convert constant value {pyvalue!r} to ONNX tensor: {exc}", + ) + attr = ir.AttrTensor("value", tensor) self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) @@ -518,7 +507,7 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: onnx.defs.OpSchema.Attribute | None = None, + attr_meta: _schemas.AttributeParameter | None = None, ) -> ir.Attr | None: """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. The following cases are supported: @@ -558,8 +547,8 @@ def _translate_attr( f"Outer scope variable '{pyvar}' referenced by function " f"'{expr.id!r}' modified.", ) - # Create GraphProto attribute - val = irfunction.to_graph_proto() + + val = irfunction.graph if isinstance(val, ir.Value): self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.") else: @@ -581,13 +570,10 @@ def _translate_attr( if attr_meta and attr_meta.required: self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = int(attr_meta.type) if attr_meta else None - attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type) - if attr_meta and (attr.type != attr_meta.type): - self.fail( - expr, - f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'", - ) + attr_type = attr_meta.type if attr_meta else None + if attr_type == ir.AttributeType.TENSOR: + val = ir.tensor(val) + attr = ir.convenience.convert_attribute(attr_name, val, attr_type) return attr def _translate_docstring(self, node: ast.Expr) -> None: @@ -805,7 +791,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value steps.append(inputs[2]) if len(starts) > 1: - axis_0_attr = self._make_onnx_attr("axis", 0) + axis_0_attr = ir.AttrInt64("axis", 0) start_name = self.generate_unique_name(f"{var_name}_start") start_value = self.emit([start_name], "Concat", starts, [axis_0_attr]) @@ -855,7 +841,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) - axis_attr = self._make_onnx_attr("axis", axis) + axis_attr = ir.AttrInt64("axis", axis) # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather @@ -880,14 +866,11 @@ def _translate_call_expr( op_signature, node.args, kwargs, fill_defaults=False ) args = [self._translate_opt_expr(x) for x in args] - attrs = [ - self._translate_attr(x, y, callee.op_schema.attributes[x]) - for x, y in attrs.items() - ] + attrs = [self._translate_attr(x, y, op_signature.get(x)) for x, y in attrs.items()] else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.op_schema, args) + args = autocast.static_cast_inputs(self, op_signature, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. @@ -896,27 +879,27 @@ def _translate_call_expr( return callee, args, attrs def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]: - schema = op.op_schema - return autocast.static_cast_inputs(self, schema, (left, right)) + op_signature = op.op_signature + return autocast.static_cast_inputs(self, op_signature, (left, right)) def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) if op not in primop_map: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) - attr = [] + attrs = [] if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): - attr = [self._make_onnx_attr("fmod", 1)] + attrs = [ir.AttrInt64("fmod", 1)] op = values.Op(self.default_opset, primop_map[op]) left, right = self._cast_like_binary_expression( op, self._translate_expr(node.left), self._translate_expr(node.right) ) - return op, [left, right], attr + return op, [left, right], attrs def _translate_unary_op_expr(self, node): op = type(node.op) @@ -1159,9 +1142,9 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno thenGraph = self._translate_block(stmt.body, f"thenGraph_{lineno}", live_defs) - thenAttr = self._make_onnx_attr("then_branch", thenGraph) + thenAttr = ir.AttrGraph("then_branch", thenGraph) elseGraph = self._translate_block(stmt.orelse, f"elseGraph_{lineno}", live_defs) - elseAttr = self._make_onnx_attr("else_branch", elseGraph) + elseAttr = ir.AttrGraph("else_branch", elseGraph) def rename(x): return self.generate_unique_name(x) @@ -1338,7 +1321,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - attrs = [self._make_onnx_attr("body", body.graph)] + attrs = [ir.AttrGraph("body", body.graph)] info = self._source_of(loop_stmt) def rename(x): @@ -1381,7 +1364,7 @@ def _translate_block( self._current_fn.outputs.append(output) else: python_var_value = None - for scope in self._locals: # TODO: skip _current_scope + for scope in reversed(self._locals): # TODO: skip _current_scope if python_var in scope: python_var_value = scope[python_var] break @@ -1435,8 +1418,8 @@ def _translate_function_signature_common( # The code can only be exported as a function. typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): - attribute_type = ta.pytype_to_attrtype(typeinfo) - attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None) + attribute_type = _schemas.get_attr_type(typeinfo) + attr = ir.Attr(x.arg, attribute_type, default_value, None) self._current_fn.append_parameter(attr) as_bool = ta.base_type_is_bool(typeinfo) self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x))) diff --git a/onnxscript/_internal/converter_test.py b/onnxscript/_internal/converter_test.py index e825697d8e..c1fe276ef5 100644 --- a/onnxscript/_internal/converter_test.py +++ b/onnxscript/_internal/converter_test.py @@ -291,7 +291,7 @@ def test_renaming(self): self.validate_save(renaming, shape_inference=False) @pytest.mark.xfail( - strict=True, + strict=False, reason="optional output is not yet implemented", ) def test_opt_output(self): diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 1415733397..9252996f99 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -21,6 +21,7 @@ import onnx import onnx.defs import onnx.reference +import onnx_ir as ir from typing_extensions import TypeAlias from onnxscript import onnx_opset, tensor @@ -123,20 +124,24 @@ def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]: @runtime_checkable class Evaluator(Protocol): - """Protocol for evaluating ONNX ops.""" + """Protocol for evaluating ONNX ops. - def eval( + NOTE: The ``eval`` method was deprecated and removed. Implement ``eval_op`` + and ``eval_function`` instead. + """ + + def eval_op( self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], ): - """Evaluates an ONNX op. + """Evaluates an Op. Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. + op: The Op to evaluate. + args: The positional arguments to the op. + kwargs: The keyword arguments to the op. """ def eval_function( @@ -175,42 +180,25 @@ def __init__(self, ignore_unknown_function_kwargs: bool = False): """ self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs - def eval( - self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], + def _adapt_inputs( + self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue] ): - """Evaluates an ONNX op. - - Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. - """ - attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self.adapt_attributes(schema, attributes) - inputs = self.adapt_inputs(schema, inputs) - outputs = self._eval(schema, inputs, attributes, closure) - return self.adapt_outputs(schema, outputs) - - def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): """Transform inputs to the expected format for the evaluator. Enables some syntactic sugar, such as the use of Python scalars, in a manner consistent with the translator. See autocast.py for details. """ - return autocast.dynamic_cast_inputs(schema, inputs) + return autocast.dynamic_cast_inputs(op_signature, inputs) - def adapt_attributes( - self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue] + def _adapt_attributes( + self, op_signature, attributes: Mapping[str, ExtendedModeValue] ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]: """Transform attributes to the expected format for the evaluator. Returns: A closure that can be used to evaluate graph-valued attributes. """ - use_graph_attribute = self.use_graph_attribute(schema) + use_graph_attribute = self.use_graph_attribute(op_signature) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): @@ -230,16 +218,15 @@ def adapt_attributes( adapted_attributes[k] = v return adapted_attributes, closure - def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]): + def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """Adapt evaluator's output to convention used in onnxscript. Onnxscript uses a tuple/sequence only when number of outputs > 1. """ - del schema # unused return outputs[0] if len(outputs) == 1 else outputs - def use_graph_attribute(self, schema: onnx.defs.OpSchema): - del schema # unused + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + del op_signature # unused return True @abc.abstractmethod @@ -259,6 +246,20 @@ def _eval( closure: The closure to use when evaluating graph-valued attributes. """ + def eval_op( + self, + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], + ): + op_signature = op.op_signature + assert op_signature is not None, f"Op {op.name} has no signature." + attributes = _unwrap_tensors_in_kwargs(kwargs) + attributes, closure = self._adapt_attributes(op_signature, attributes) + inputs = self._adapt_inputs(op_signature, args) + outputs = self._eval(op.op_schema, inputs, attributes, closure) + return self._adapt_outputs(outputs) + def eval_function( self, function: values.OnnxFunction, @@ -275,6 +276,8 @@ def eval_function( kwargs: The keyword arguments to the function. """ op_signature = function.op_signature + if op_signature is None: + raise RuntimeError(f"Function {function.name} has no signature.") # Split happens in the evaluator instead of the OnnxFunction __call__ method # so that evaluators can control behaviors like whether to fill in default values for attributes. tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature( @@ -416,12 +419,15 @@ def _prepare_model_and_inputs_for_eager( implicit_args = {k: _onnxscript_to_numpy_value(v) for k, v in implicit_args.items()} # Utility to convert kwarg to ONNX AttributeProto: + # TODO(justinchuby): Clean up this function to use onnx-ir def make_attr(key: str, value: Any) -> onnx.AttributeProto: def make_tensor_name() -> str: return f"attr_{key}" - return autocast.pyvalue_to_onnx_attribute( - key, value, make_tensor_name, int(schema.attributes[key].type) + return ir.to_proto( + autocast.pyvalue_to_onnx_attribute( + key, value, make_tensor_name, int(schema.attributes[key].type) + ) ) # Construct ONNX model with a single op call: @@ -504,8 +510,14 @@ def _call_ort( return [_numpy_to_onnxscript_value(x) for x in result] -def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]: - return schema.name, schema.domain, schema.since_version +def _op_identifier( + op_schema_or_signature: onnx.defs.OpSchema | _schemas.OpSignature, +) -> tuple[str, str, int]: + return ( + op_schema_or_signature.name, + op_schema_or_signature.domain, + op_schema_or_signature.since_version, + ) class ORTEvaluator(BaseEvaluator): @@ -552,13 +564,13 @@ def __init__(self) -> None: super().__init__() self._python_ops: dict[tuple[str, str, int], Any] = {} - def use_graph_attribute(self, schema: onnx.defs.OpSchema) -> bool: - return _schema_id(schema) not in self._python_ops + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + return _op_identifier(op_signature) not in self._python_ops def _eval(self, schema, inputs, attributes, closure): - schemaid = _schema_id(schema) - if schemaid in self._python_ops: - return self._python_ops[schemaid](inputs, attributes) + identifier = _op_identifier(schema) + if identifier in self._python_ops: + return self._python_ops[identifier](inputs, attributes) else: return super()._eval(schema, inputs, attributes, closure) @@ -566,8 +578,8 @@ def register(self, opset: values.Opset) -> Callable[[_T], _T]: assert opset is not None def decorator(function: _T) -> _T: - schema = opset[function.__name__] - self._python_ops[_schema_id(schema)] = function + op_signature = opset[function.__name__].op_signature + self._python_ops[_op_identifier(op_signature)] = function return function return decorator diff --git a/onnxscript/_internal/evaluator_test.py b/onnxscript/_internal/evaluator_test.py index c696ddf9b4..4949c04675 100644 --- a/onnxscript/_internal/evaluator_test.py +++ b/onnxscript/_internal/evaluator_test.py @@ -31,11 +31,13 @@ def square(y: FLOAT["N"]) -> FLOAT["N"]: # noqa: F821 np.testing.assert_equal(output, expected) # Test using ort-mixed-evaluator - output = seq_map[evaluator.ort_mixed_evaluator](x) + with evaluator.default_as(evaluator.ort_mixed_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) # Test using ort-evaluator - output = seq_map[evaluator.ort_evaluator](x) + with evaluator.default_as(evaluator.ort_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) diff --git a/onnxscript/_internal/irbuilder.py b/onnxscript/_internal/irbuilder.py index e5fa80622e..1ae3c7bdb1 100644 --- a/onnxscript/_internal/irbuilder.py +++ b/onnxscript/_internal/irbuilder.py @@ -77,7 +77,7 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def get_called_functions(self) -> dict[str, onnx.FunctionProto]: + def get_called_functions(self) -> dict[str, values.OnnxFunction]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): @@ -94,12 +94,12 @@ def add(f: values.OnnxFunction): visit(self) - return {name: f.to_function_proto() for name, f in called_functions.items()} + return called_functions def to_graph_proto(self) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`.""" - return ir.to_proto(self.graph) + return ir.serde.serialize_graph(self.graph) def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" - return ir.to_proto(self) + return ir.serde.serialize_function(self) diff --git a/onnxscript/_internal/type_annotation.py b/onnxscript/_internal/type_annotation.py index fb7b8a370d..a5509f92b6 100644 --- a/onnxscript/_internal/type_annotation.py +++ b/onnxscript/_internal/type_annotation.py @@ -22,22 +22,8 @@ # - Above types with annotation attached TypeAnnotationValue = typing.Any -# Map from python type to corresponding ONNX AttributeProto type -_PYTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOAT, - int: onnx.AttributeProto.INT, - str: onnx.AttributeProto.STRING, - bool: onnx.AttributeProto.INT, # experimental -} -# Map from python type to corresponding ONNX AttributeProto type, -# for repeated (i.e., list of) values -_LISTTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOATS, - int: onnx.AttributeProto.INTS, - str: onnx.AttributeProto.STRINGS, - bool: onnx.AttributeProto.INTS, # experimental -} +_PRIMITIVE_ATTR_TYPE = frozenset((float, int, str, bool)) _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) @@ -92,33 +78,13 @@ def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: - return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP - - -def pytype_to_attrtype( - pytype: TypeAnnotationValue, -) -> Optional[onnx.AttributeProto.AttributeType]: - pytype = _remove_annotation(pytype) - if pytype in _PYTYPE_TO_ATTRTYPE_MAP: - return _PYTYPE_TO_ATTRTYPE_MAP[pytype] - type_constructor = typing.get_origin(pytype) - # Remove Optional wrapper if present, which is represented as an Union[..., type(None)] - if type_constructor is typing.Union: - # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)] - args = [x for x in typing.get_args(pytype) if x is not type(None)] - if len(args) == 1: - return pytype_to_attrtype(args[0]) - if type_constructor in _LIST_CONSTRUCTORS: - elt_type = typing.get_args(pytype)[0] - if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: - return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None + return typeinfo in _PRIMITIVE_ATTR_TYPE def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: """Returns True if base type of pytype is bool, False otherwise.""" pytype = _remove_annotation(pytype) - if pytype in _PYTYPE_TO_ATTRTYPE_MAP: + if pytype in _PRIMITIVE_ATTR_TYPE: return pytype is bool type_constructor = typing.get_origin(pytype) if type_constructor in _LIST_CONSTRUCTORS: @@ -215,82 +181,3 @@ def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: if typing.get_origin(typeinfo) is tuple: return typing.get_args(typeinfo) return (typeinfo,) - - -def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]: - """Returns a list of type-strings corresponding to a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - A list of all supported input types for the given type annotation. - Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS. - """ - if pytype is None: - return list(ALL_TENSOR_TYPE_STRINGS) - if pytype is onnx_types.TensorType: - return list(ALL_TENSOR_TYPE_STRINGS) - if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, typing.TypeVar): - constraints = pytype.__constraints__ - if constraints: - return pytype_to_type_strings(Union.__getitem__(constraints)) # pylint: disable=unnecessary-dunder-call - bound = pytype.__bound__ - if bound is None: - return list(ALL_TENSOR_TYPE_STRINGS) - return pytype_to_type_strings(bound) - if typing.get_origin(pytype) is Union: - options = [] - subtypes = typing.get_args(pytype) - # A None type in a Union is equivalent to an optional type - optional = is_optional(pytype) - for subtype in subtypes: - if subtype is type(None): - # Skip None type because we are handling it with is_optional - continue - if optional: - options += [ - *pytype_to_type_strings(subtype), - *[f"optional({s})" for s in pytype_to_type_strings(subtype)], - ] - else: - options += pytype_to_type_strings(subtype) - # Remove duplicates - return sorted(set(options)) - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])] - - raise ValueError(f"Unsupported type: {pytype}") - - -def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: - """Returns the name of the type constraint for a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - The name of the type constraint if it is a TypeVar. - - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. - - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. - - Returns None if the type annotation does not have a type constraint. - """ - if isinstance(pytype, typing.TypeVar): - return pytype.__name__ - if is_optional(pytype): - subtypes = typing.get_args(pytype) - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = get_type_constraint_name(subtype) - return f"Optional_{type_param_name}" if type_param_name else None - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - type_param_name = get_type_constraint_name(subtypes[0]) - return f"Sequence_{type_param_name}" if type_param_name else None - return None diff --git a/onnxscript/_internal/type_annotation_test.py b/onnxscript/_internal/type_annotation_test.py index 259157e66d..622a04339a 100644 --- a/onnxscript/_internal/type_annotation_test.py +++ b/onnxscript/_internal/type_annotation_test.py @@ -2,14 +2,9 @@ # Licensed under the MIT License. import unittest -from typing import Any, List, Optional, Sequence, TypeVar, Union -import parameterized - -import onnxscript import onnxscript.testing -from onnxscript import FLOAT, INT64, script -from onnxscript._internal import type_annotation +from onnxscript import FLOAT, script from onnxscript.onnx_opset import opset15 as op from tests.common import testutils @@ -90,157 +85,5 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: ) -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) - - -class TypeConversionFunctionsTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ( - "tensor_type_all", - onnxscript.onnx_types.TensorType, - list(type_annotation.ALL_TENSOR_TYPE_STRINGS), - ), - ("none", None, list(type_annotation.ALL_TENSOR_TYPE_STRINGS)), - ("tensor_type", INT64, ["tensor(int64)"]), - ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), - ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), - ("tensor_type_shape", INT64[10], ["tensor(int64)"]), - ( - "type_var_constraints", - _TestTypeVarConstraints, - ["tensor(float)", "tensor(int64)"], - ), - ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), - ("type_bound_two", _TestTypeVarTwoBound, ["tensor(float)", "tensor(int64)"]), - ( - "optional_tensor_type_all", - Optional[onnxscript.onnx_types.TensorType], - [ - *[ - f"optional({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - *type_annotation.ALL_TENSOR_TYPE_STRINGS, - ], - ), - ( - "optional_tensor_type", - Optional[INT64], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_union", - Optional[Union[INT64, FLOAT]], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_tensor_type_variadic_shape", - Optional[INT64[...]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_shape", - Optional[INT64[10]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_var_constraints", - Optional[_TestTypeVarConstraints], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_type_bound_one", - Optional[_TestTypeVarOneBound], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_bound_two", - Optional[_TestTypeVarTwoBound], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "sequence_type_all", - Sequence[onnxscript.onnx_types.TensorType], - [ - f"seq({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - ), - ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]), - ( - "union_sequence_type", - Union[Sequence[INT64], Sequence[FLOAT]], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_variadic_shape", - Sequence[INT64[...]], - ["seq(tensor(int64))"], - ), - ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]), - ( - "sequence_type_var_constraints", - Sequence[_TestTypeVarConstraints], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_bound_one", - Sequence[_TestTypeVarOneBound], - ["seq(tensor(int64))"], - ), - ( - "sequence_type_bound_two", - Sequence[_TestTypeVarTwoBound], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ] - ) - def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]): - self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected) - - @parameterized.parameterized.expand( - [ - ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), - ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), - ( - "optional_type_var", - Optional[_TestTypeVarOneBound], - "Optional__TestTypeVarOneBound", - ), - ( - "sequence_type_var", - Sequence[_TestTypeVarOneBound], - "Sequence__TestTypeVarOneBound", - ), - ("normal_type", INT64, None), - ("union_type", Union[INT64, FLOAT], None), - ("optional_type", Optional[INT64], None), - ("sequence_type", Sequence[INT64], None), - ("optional_sequence_type", Optional[Sequence[INT64]], None), - ("optional_union_type", Optional[Union[INT64, FLOAT]], None), - ] - ) - def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): - self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) - - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 2f22e1eefa..5c559a5504 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -11,7 +11,7 @@ import logging import types import typing -from typing import ( # type: ignore[attr-defined] +from typing import ( Any, Callable, ClassVar, @@ -27,7 +27,7 @@ import onnx_ir as ir from typing_extensions import ParamSpec -from onnxscript._internal import ast_utils, deprecation, irbuilder, sourceinfo, type_annotation +from onnxscript._internal import ast_utils, irbuilder, sourceinfo from onnxscript._internal import converter as converter_module from onnxscript.ir import _schemas from onnxscript.onnx_types import ONNXType @@ -107,7 +107,8 @@ def __repr__(self): def __getitem__(self, opname): try: - return onnx.defs.get_schema(opname, self.version, self.domain) + schema = onnx.defs.get_schema(opname, self.version, self.domain) + return Op(self, opname, schema) except Exception: # pylint: disable=broad-except # TODO: more specific exception return None @@ -122,11 +123,11 @@ def __contains__(self, opname): def __str__(self) -> str: return self.domain - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Op: try: schema = onnx.defs.get_schema(attr, self.version, self.domain) return Op(self, attr, schema) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-exception-caught raise AttributeError(f"Attribute {attr} not found.") from exc def add_function_def(self, fun): @@ -168,9 +169,6 @@ def name(self) -> str: ... @property def opset(self) -> Opset: ... - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... - @property def op_signature(self) -> Optional[_schemas.OpSignature]: ... @@ -191,7 +189,13 @@ def __init__( ) -> None: self._opset = opset self._name = name - self._op_schema = op_schema or opset[name] + self._op_schema: onnx.defs.OpSchema | None + if op_schema is not None: + self._op_schema = op_schema + elif (op := opset[name]) is not None: + self._op_schema = op.op_schema + else: + self._op_schema = None self._signature: Optional[_schemas.OpSignature] = None if self._op_schema is None: @@ -203,15 +207,16 @@ def __init__( ) def __call__(self, *args, **kwargs): - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel - schema = self.op_schema - if schema is None: - raise RuntimeError( - f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." - ) - return evaluator.default().eval(schema, args, kwargs) + default_evaluator = evaluator.default() + if hasattr(default_evaluator, "eval"): + # Interface prior to onnxscript 0.6, used by PyTorch 2.10 and older + if self.op_schema is None: + raise ValueError(f"OpSchema not found for op '{self.name}'.") + return default_evaluator.eval(self.op_schema, args, kwargs) + # Use the new interface + return evaluator.default().eval_op(self, args, kwargs) @property def name(self) -> str: @@ -225,10 +230,6 @@ def opset(self) -> Opset: def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema - def has_schema(self) -> bool: - """Returns True if this op has an OpSchema.""" - return self.op_schema is not None - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" @@ -261,99 +262,6 @@ class OnnxClosure: function: Any -@dataclasses.dataclass -class TypeConstraint: - """Represents a type constraint for an ONNX op. - - Attributes: - name: The name of the type constraint. - allowed_types: The allowed types for the type constraint. - """ - - name: str - allowed_types: list[str] - description: str = "" - - def as_tuple(self) -> tuple[str, list[str], str]: - """Returns the type constraint as a tuple.""" - return (self.name, self.allowed_types, self.description) - - -def _op_schema_from_function_ir( - function_ir: irbuilder.IRFunction, opset: Opset -) -> onnx.defs.OpSchema: - """Construct an ONNX OpSchema from an IRFunction.""" - - # Find all distinct types in the inputs and outputs - distinct_types = {_typeinfo(arg) for arg in function_ir.inputs}.union( - {_typeinfo(arg) for arg in function_ir.outputs} - ) - # Create a mapping from type to a unique name - type_to_constraint = {} - for i, type_ in enumerate(distinct_types): - name = f"T{i}" - type_to_constraint[type_] = TypeConstraint( - name=type_annotation.get_type_constraint_name(type_) or name, - allowed_types=type_annotation.pytype_to_type_strings(type_), - ) - - formal_inputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.inputs - ] - formal_outputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.outputs - ] - return onnx.defs.OpSchema( - function_ir.name, - opset.domain, - since_version=opset.version, - doc=function_ir.doc_string or "", - inputs=formal_inputs, - outputs=formal_outputs, - type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], - attributes=[ - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), # type: ignore[call-arg] - ) - for attr in function_ir.attrs - if attr.value is None - ], - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - default_value=ir.to_proto(attr), - ) - for attr in function_ir.attrs - if attr.value is not None - ], - ], - ) - - class OnnxFunction(Op, Generic[_P, _R]): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -389,9 +297,13 @@ def __init__( super().__init__(opset, irfun.name) self.function = pyfun self.function_ir = irfun + # Record the function domain's opset version in the function metadata + self.function_ir.meta["opset_version"] = opset.version self.source = source self.kwargs = kwargs - self._op_schema: Optional[onnx.defs.OpSchema] = None + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, pyfun) @@ -399,67 +311,17 @@ def __init__( # Experimental fields self.traceable = False - @property - @deprecation.deprecated( - since="0.1", - removed_in="the future", - instructions="use '.name' instead", - ) - def opname(self) -> str: - # NOTE: This is a temporary alias for backward compatibility with PyTorch 2.0. - # TODO: Remove this in onnxscript 0.3. - return self.name - - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Construct an OpSchema from function_ir.""" - if self._op_schema is not None: - return self._op_schema - - self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.function, domain=self.function_ir.domain, name=self.name - ) return self._signature @op_signature.setter def op_signature(self, value: _schemas.OpSignature): self._signature = value - def __getitem__(self, instance): - """Returns a lambda to evaluate function using given evaluator instance. - - Usage: - script_fun(X) executes the function using the default evaluator instance. - script_fun[instance](X) executes the function using the given evaluator instance. - """ - - def fun(*args, **kwargs): - # FIXME(after #225): Move import to the top of the file. - from onnxscript._internal import ( # pylint: disable=import-outside-toplevel - evaluator, - ) - - with evaluator.default_as(instance): - return self.__call__(*args, **kwargs) - - return fun - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Implements an eager-mode execution of an onnxscript function.""" - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] @@ -490,7 +352,6 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions=None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -501,8 +362,6 @@ def _to_model_proto( """Converts this instance into a `onnx.ModelProto`. Args: - functions: A list of functions to include in the model. - By default, all functions called at least once are included. io_types: When specified, all the inputs/outputs of the model are set to be of this type. input_types: When specified, all the inputs of the model @@ -519,35 +378,26 @@ def _to_model_proto( An instance of :class:`onnx.ModelProto`. """ # Identify functions to include in the model - if functions is None: - sub_functions = self.function_ir.get_called_functions() - functions = sub_functions.values() - else: - - def to_proto(f): - if isinstance(f, onnx.FunctionProto): - return f - if isinstance(f, OnnxFunction): - return f.to_function_proto() - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") - - functions = [to_proto(f) for f in functions] + sub_functions = self.function_ir.get_called_functions() + ir_functions: list[ir.Function] = [func.function_ir for func in sub_functions.values()] + # Duplicate the graph to create the model + main_graph = self.function_ir.graph.clone() # Determine opset imports - opsets = self.function_ir.graph.opset_imports + opset_imports = main_graph.opset_imports.copy() - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 - # TODO(rama): Handle conflicts with appropriate error/warning message. - for opset in proto.opset_import: - if opset.domain not in opsets: - opsets[opset.domain] = opset.version + for func in ir_functions: + domain = func.domain + if domain is not None and domain not in opset_imports: + opset_imports[domain] = func.meta.get("opset_version", 1) - if "" not in opsets: + if "" not in opset_imports and "" in func.opset_imports: + opset_imports[""] = func.opset_imports[""] + + if "" not in opset_imports: # No operator is using the standard opset. # Use the specified version if provided or the default value. - opsets[""] = ( + opset_imports[""] = ( opset_version if opset_version is not None else onnx.defs.onnx_opset_version() ) @@ -555,12 +405,23 @@ def to_proto(f): if "ir_version" in kwargs: ir_version = kwargs.pop("ir_version") else: - ir_version = select_ir_version(opsets[""]) + ir_version = select_ir_version(opset_imports[""]) # Create the model - model = ir.Model(self.function_ir.graph, ir_version=ir_version) + model = ir.Model(main_graph, ir_version=ir_version) + for func in ir_functions: + model.functions[func.identifier()] = func + model_proto = ir.to_proto(model) - model_proto.functions.extend(functions) + + # Update opset imports + del model_proto.opset_import[:] + model_proto.opset_import.extend( + [ + onnx.OperatorSetIdProto(domain=domain, version=version) + for domain, version in opset_imports.items() + ] + ) # Set additional type information if provided graph = model_proto.graph @@ -604,6 +465,9 @@ class TracedOnnxFunction(Op): def __init__(self, opset: Opset, func: Callable): super().__init__(opset, func.__name__) self.func = func + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, func) @@ -633,30 +497,9 @@ def function_ir(self) -> irbuilder.IRFunction: return converter.translate_function_signature(func_ast) - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Return the OpSchema.""" - - if self._op_schema is not None: - return self._op_schema - - # FIXME(justinchuby): outputs are empty. Need to fix. - self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.func, domain="_traced", name=self.name - ) return self._signature @op_signature.setter diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..5a37fef34a 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -106,8 +106,10 @@ class Parameter: type_constraint: TypeConstraintParam required: bool variadic: bool + homogeneous: bool = True + min_arity: int = 1 + # TODO: Add differentiation_category default: Any = _EMPTY_DEFAULT - # TODO: Add other properties too def __str__(self) -> str: type_str = self.type_constraint.name @@ -188,6 +190,8 @@ def _convert_formal_parameter( type_constraint=type_constraint, required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + homogeneous=param.is_homogeneous, + min_arity=param.min_arity, ) @@ -210,7 +214,7 @@ def _is_optional(type_: type) -> bool: return False -def _get_attr_type(type_: type) -> ir.AttributeType: +def get_attr_type(type_: type) -> ir.AttributeType: """Obtain the type of the attribute from a Python class.""" try: if type_ in _PY_TYPE_TO_ATTR_TYPE: @@ -339,6 +343,7 @@ class OpSignature: params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( init=False, repr=False ) + since_version: int = 1 def __post_init__(self): self.params_map = {param.name: param for param in self.params} @@ -369,6 +374,16 @@ def __str__(self) -> str: ) return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + @property + def inputs(self) -> Sequence[Parameter]: + """Returns the input parameters.""" + return [param for param in self.params if isinstance(param, Parameter)] + + @property + def attributes(self) -> Sequence[AttributeParameter]: + """Returns the attribute parameters.""" + return [param for param in self.params if isinstance(param, AttributeParameter)] + @classmethod def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: """Produce an OpSignature from an ONNX OpSchema.""" @@ -415,11 +430,17 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: overload="", params=params, outputs=outputs, + since_version=op_schema.since_version, ) @classmethod def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "" + cls, + func, + domain: str, + name: str | None = None, + overload: str = "", + since_version: int = 1, ) -> OpSignature: """Produce an OpSignature from a function using type annotation.""" @@ -434,7 +455,7 @@ def from_function( for param in py_signature.parameters.values(): if param.name not in type_hints: - logger.warning( + logger.debug( "Missing annotation for parameter '%s' from %s. Treating as an Input.", param.name, py_signature, @@ -448,6 +469,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -455,7 +477,7 @@ def from_function( ) else: type_ = type_hints[param.name] - if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED: # Construct the default attribute if param.default is not inspect.Parameter.empty: # TODO: Use ir_convenience instead to handle int as float @@ -498,6 +520,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -535,6 +558,7 @@ def from_function( type_constraint=type_constraint, required=True, variadic=False, + homogeneous=True, default=_EMPTY_DEFAULT, ) ) @@ -545,4 +569,5 @@ def from_function( overload=overload, params=params, outputs=outputs, + since_version=since_version, ) diff --git a/onnxscript/tensor.py b/onnxscript/tensor.py index f1d781b808..6ad8f6bf12 100644 --- a/onnxscript/tensor.py +++ b/onnxscript/tensor.py @@ -16,6 +16,8 @@ class Tensor: Serves to define overloaded ops with an ONNX/ONNXScript semantics. """ + # TODO(justinchuby): Remove the tensor class and use ir.Value instead + def __init__(self, nparray: Optional[np.ndarray], opset=None): if nparray is not None and not isinstance(nparray, np.ndarray): raise TypeError( diff --git a/pyproject.toml b/pyproject.toml index 37cdf3d4ea..1f9025fb9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.13,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.15,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.17", "packaging", "typing_extensions>=4.10", diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index a45050fb22..beb74b5462 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -109,17 +109,6 @@ def test_script_function_passes_checker( function_proto = torchlib_op_info.op.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] - @parameterized.parameterized.expand( - [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] - ) - def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): - func = torchlib_op_info.op - if not hasattr(func, "op_schema"): - raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") - schema = func.op_schema - self.assertIsNotNone(schema) - self.assertEqual(schema.name, func.name) - def run_test_output_match( test_suite: unittest.TestCase, @@ -157,12 +146,12 @@ def run_test_output_match( onnx_function = torchlib_op_info.op input_wrangler = torchlib_op_info.input_wrangler if ( - not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema) + not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_signature) and dtype not in COMPLEX_TYPES ): test_suite.skipTest( f"dtype '{dtype}' is not supported by the op '{op.name}'. " - f"Type constraints: {onnx_function.op_schema.type_constraints}" + f"Type constraints: {onnx_function.op_signature.params}" ) # Obtain the tolerance for the op diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 99594ee17e..f2deb7cd07 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -36,6 +36,7 @@ import onnxscript import onnxscript.evaluator from onnxscript import ir +from onnxscript.ir import _schemas from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") @@ -394,23 +395,23 @@ def _format_model_and_input_information(onnx_model, inputs): return f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}" -TORCH_DTYPE_TO_ONNX_STRING = { - torch.bool: "tensor(bool)", - torch.uint8: "tensor(uint8)", - torch.int8: "tensor(int8)", - torch.int16: "tensor(int16)", - torch.int32: "tensor(int32)", - torch.int64: "tensor(int64)", - torch.float16: "tensor(float16)", - torch.float32: "tensor(float)", - torch.float64: "tensor(double)", - torch.complex64: "tensor(complex64)", - torch.complex128: "tensor(complex128)", - torch.bfloat16: "tensor(bfloat16)", +_TORCH_DTYPE_TO_ONNX_TYPE = { + torch.bool: ir.DataType.BOOL, + torch.uint8: ir.DataType.UINT8, + torch.int8: ir.DataType.INT8, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.complex64: ir.DataType.COMPLEX64, + torch.complex128: ir.DataType.COMPLEX128, + torch.bfloat16: ir.DataType.BFLOAT16, } -def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: +def dtype_op_schema_compatible(dtype: torch.dtype, schema: _schemas.OpSignature) -> bool: """Checks if the dtype is compatible with the schema. When a dtype is "compatible" with the schema, it means we can use the dtype @@ -418,16 +419,17 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - Args: dtype: The torch dtype used to create sample inputs by OpInfo. - schema: The ONNX schema of the function. + schema: The OpSignature of the function. Returns: True if the dtype is compatible with the schema. """ - if not schema.inputs: + inputs = schema.inputs + if not inputs: # If there are no inputs, we can't check compatibility. Assume it is compatible. # e.g. aten_randn has only attributes. return True - if schema.inputs[0].name not in {"self", "input"}: + if inputs[0].name not in {"self", "input"}: # If the name of the first input is not "self" or "input", # it is usually an input that is not of the same type as the output. # We assume support in this case. @@ -457,16 +459,12 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - # 'tensor(bfloat16)']. # Since torch.float32 (tensor(float)) is in the allowed types, we return True. - first_input_type_name = schema.inputs[0].type_str - # Find the type constraint for the first input by matching the parameter name - first_input_type_constraint = next( - (x for x in schema.type_constraints if first_input_type_name in x.type_param_str), - None, - ) + first_input_type_constraint = inputs[0].type_constraint assert first_input_type_constraint is not None - allowed_type_strs = first_input_type_constraint.allowed_type_strs + allowed_types = first_input_type_constraint.allowed_types # Here we consider seq(tensor(float)) compatible with tensor(float) as well - return any(TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs) + allowed_dtypes = {type_.dtype for type_ in allowed_types} + return _TORCH_DTYPE_TO_ONNX_TYPE[dtype] in allowed_dtypes def graph_executor(