diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 447dbc56b0..04db72f102 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1215,14 +1215,15 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: f"'{engine_node.target}' not found on graph module" ) elif engine_node.op == "placeholder": - constants = getattr(exp_program, "constants", {}) - engine_obj = constants.get(engine_node.name) or constants.get( - engine_node.target - ) + from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj + + engine_obj = _resolve_lifted_custom_obj(exp_program, engine_node) if engine_obj is None: raise RuntimeError( f"execute_engine node '{node.name}': placeholder engine " - f"'{engine_node.name}' not found in exp_program.constants" + f"'{engine_node.name}' did not resolve to a lifted " + f"custom-object constant (available: " + f"{sorted(getattr(exp_program, 'constants', {}) or {})})" ) else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index d844b8d92c..ab4d3e3d49 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -6,6 +6,7 @@ import torch from torch._export.non_strict_utils import make_constraints from torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject from torch._subclasses.fake_tensor import FakeTensor from torch.export import ExportedProgram, ExportGraphSignature from torch.export._trace import _combine_args @@ -23,6 +24,37 @@ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX +def _resolve_lifted_custom_obj( + exp_program: ExportedProgram, node: torch.fx.Node +) -> Any: + # torch.export lifts custom objects into exp_program.constants keyed by their + # graph-signature FQN and renames the placeholder node, so constants[node.name] + # misses. Resolve name -> FQN through the signature mapping; the direct + # name/target lookup is only for legacy programs that carry no such mapping. + constants = getattr(exp_program, "constants", {}) or {} + sig = getattr(exp_program, "graph_signature", None) + name_to_fqn = ( + getattr(sig, "inputs_to_lifted_custom_objs", {}) or {} + if sig is not None + else {} + ) + + obj = None + fqn = name_to_fqn.get(node.name) + if fqn is not None: + obj = constants.get(fqn) + elif not name_to_fqn: + for key in (node.target, node.name): + if key in constants: + obj = constants[key] + break + + # A FakeScriptObject has no __getstate__; callers need the real object. + if isinstance(obj, FakeScriptObject): + obj = obj.real_obj + return obj + + def export( gm: torch.fx.GraphModule, *, diff --git a/py/torch_tensorrt/executorch/backend.py b/py/torch_tensorrt/executorch/backend.py index 03c7236afa..c4b29af24c 100644 --- a/py/torch_tensorrt/executorch/backend.py +++ b/py/torch_tensorrt/executorch/backend.py @@ -10,6 +10,7 @@ PreprocessResult, ) from torch.export.exported_program import ExportedProgram +from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( DEVICE_IDX, ENGINE_IDX, @@ -135,14 +136,13 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An f"'{engine_node.target}' not found on graph module" ) elif engine_node.op == "placeholder": - constants = getattr(edge_program, "constants", {}) - engine_obj = constants.get(engine_node.name) or constants.get( - engine_node.target - ) + engine_obj = _resolve_lifted_custom_obj(edge_program, engine_node) if engine_obj is None: raise RuntimeError( f"execute_engine node '{node.name}': placeholder engine " - f"'{engine_node.name}' not found in edge_program.constants" + f"'{engine_node.name}' did not resolve to a lifted custom-object " + f"constant (available: " + f"{sorted(getattr(edge_program, 'constants', {}) or {})})" ) else: raise RuntimeError( diff --git a/tests/py/dynamo/executorch/test_api.py b/tests/py/dynamo/executorch/test_api.py index 1fff565173..74fe5a5f7e 100644 --- a/tests/py/dynamo/executorch/test_api.py +++ b/tests/py/dynamo/executorch/test_api.py @@ -1,10 +1,14 @@ import ast import importlib import sys +import types from pathlib import Path import pytest import torch +from torch._library.fake_class_registry import FakeScriptObject + +from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj @pytest.mark.unit @@ -144,3 +148,57 @@ def test_executorch_headers_are_not_dlfw_gated(): isinstance(node, ast.Name) and node.id == "IS_DLFW_CI" for node in ast.walk(header_package_data) ) + + +def _stub_node(name, target=None): + return types.SimpleNamespace(name=name, target=name if target is None else target) + + +def _stub_exported_program(constants, name_to_fqn=None): + sig = ( + None + if name_to_fqn is None + else types.SimpleNamespace(inputs_to_lifted_custom_objs=name_to_fqn) + ) + return types.SimpleNamespace(constants=constants, graph_signature=sig) + + +@pytest.mark.unit +def test_resolve_lifted_custom_obj_via_signature_fqn(): + # Modern torch.export: placeholder name differs from the constants FQN key. + sentinel = object() + ep = _stub_exported_program({"engine_fqn": sentinel}, {"obj_engine": "engine_fqn"}) + assert _resolve_lifted_custom_obj(ep, _stub_node("obj_engine")) is sentinel + + +@pytest.mark.unit +def test_resolve_lifted_custom_obj_legacy_fallback(): + # No signature mapping: fall back to a direct name/target lookup. + sentinel = object() + ep = _stub_exported_program({"engine": sentinel}, name_to_fqn=None) + assert _resolve_lifted_custom_obj(ep, _stub_node("engine")) is sentinel + + +@pytest.mark.unit +def test_resolve_lifted_custom_obj_signature_present_name_absent_is_none(): + # A present-but-incomplete mapping must not bind a different object by name. + ep = _stub_exported_program({"engine": object()}, {"some_other_obj": "x"}) + assert _resolve_lifted_custom_obj(ep, _stub_node("engine")) is None + + +@pytest.mark.unit +def test_resolve_lifted_custom_obj_missing_is_none(): + ep = _stub_exported_program({}, name_to_fqn=None) + assert _resolve_lifted_custom_obj(ep, _stub_node("missing")) is None + + +@pytest.mark.unit +def test_resolve_lifted_custom_obj_unwraps_fake_script_object(): + class _Real: + pass + + fake = FakeScriptObject(object(), "Engine", _Real()) + ep = _stub_exported_program({"engine_fqn": fake}, {"obj_engine": "engine_fqn"}) + resolved = _resolve_lifted_custom_obj(ep, _stub_node("obj_engine")) + assert not isinstance(resolved, FakeScriptObject) + assert isinstance(resolved, _Real)