Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,14 +1215,15 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
f"'{engine_node.target}' not found on graph module"
)
elif engine_node.op == "placeholder":
constants = getattr(exp_program, "constants", {})
engine_obj = constants.get(engine_node.name) or constants.get(
engine_node.target
)
from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj

engine_obj = _resolve_lifted_custom_obj(exp_program, engine_node)
if engine_obj is None:
raise RuntimeError(
f"execute_engine node '{node.name}': placeholder engine "
f"'{engine_node.name}' not found in exp_program.constants"
f"'{engine_node.name}' did not resolve to a lifted "
f"custom-object constant (available: "
f"{sorted(getattr(exp_program, 'constants', {}) or {})})"
)
else:
raise RuntimeError(
Expand Down
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch._export.non_strict_utils import make_constraints
from torch._guards import detect_fake_mode
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export._trace import _combine_args
Expand All @@ -23,6 +24,37 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX


def _resolve_lifted_custom_obj(
exp_program: ExportedProgram, node: torch.fx.Node
) -> Any:
# torch.export lifts custom objects into exp_program.constants keyed by their
# graph-signature FQN and renames the placeholder node, so constants[node.name]
# misses. Resolve name -> FQN through the signature mapping; the direct
# name/target lookup is only for legacy programs that carry no such mapping.
constants = getattr(exp_program, "constants", {}) or {}
sig = getattr(exp_program, "graph_signature", None)
name_to_fqn = (
getattr(sig, "inputs_to_lifted_custom_objs", {}) or {}
if sig is not None
else {}
)

obj = None
fqn = name_to_fqn.get(node.name)
if fqn is not None:
obj = constants.get(fqn)
elif not name_to_fqn:
for key in (node.target, node.name):
if key in constants:
obj = constants[key]
break

# A FakeScriptObject has no __getstate__; callers need the real object.
if isinstance(obj, FakeScriptObject):
obj = obj.real_obj
return obj


def export(
gm: torch.fx.GraphModule,
*,
Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/executorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
DEVICE_IDX,
ENGINE_IDX,
Expand Down Expand Up @@ -135,14 +136,13 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An
f"'{engine_node.target}' not found on graph module"
)
elif engine_node.op == "placeholder":
constants = getattr(edge_program, "constants", {})
engine_obj = constants.get(engine_node.name) or constants.get(
engine_node.target
)
engine_obj = _resolve_lifted_custom_obj(edge_program, engine_node)
if engine_obj is None:
raise RuntimeError(
f"execute_engine node '{node.name}': placeholder engine "
f"'{engine_node.name}' not found in edge_program.constants"
f"'{engine_node.name}' did not resolve to a lifted custom-object "
f"constant (available: "
f"{sorted(getattr(edge_program, 'constants', {}) or {})})"
)
else:
raise RuntimeError(
Expand Down
58 changes: 58 additions & 0 deletions tests/py/dynamo/executorch/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import ast
import importlib
import sys
import types
from pathlib import Path

import pytest
import torch
from torch._library.fake_class_registry import FakeScriptObject

from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj


@pytest.mark.unit
Expand Down Expand Up @@ -144,3 +148,57 @@ def test_executorch_headers_are_not_dlfw_gated():
isinstance(node, ast.Name) and node.id == "IS_DLFW_CI"
for node in ast.walk(header_package_data)
)


def _stub_node(name, target=None):
return types.SimpleNamespace(name=name, target=name if target is None else target)


def _stub_exported_program(constants, name_to_fqn=None):
sig = (
None
if name_to_fqn is None
else types.SimpleNamespace(inputs_to_lifted_custom_objs=name_to_fqn)
)
return types.SimpleNamespace(constants=constants, graph_signature=sig)


@pytest.mark.unit
def test_resolve_lifted_custom_obj_via_signature_fqn():
# Modern torch.export: placeholder name differs from the constants FQN key.
sentinel = object()
ep = _stub_exported_program({"engine_fqn": sentinel}, {"obj_engine": "engine_fqn"})
assert _resolve_lifted_custom_obj(ep, _stub_node("obj_engine")) is sentinel


@pytest.mark.unit
def test_resolve_lifted_custom_obj_legacy_fallback():
# No signature mapping: fall back to a direct name/target lookup.
sentinel = object()
ep = _stub_exported_program({"engine": sentinel}, name_to_fqn=None)
assert _resolve_lifted_custom_obj(ep, _stub_node("engine")) is sentinel


@pytest.mark.unit
def test_resolve_lifted_custom_obj_signature_present_name_absent_is_none():
# A present-but-incomplete mapping must not bind a different object by name.
ep = _stub_exported_program({"engine": object()}, {"some_other_obj": "x"})
assert _resolve_lifted_custom_obj(ep, _stub_node("engine")) is None


@pytest.mark.unit
def test_resolve_lifted_custom_obj_missing_is_none():
ep = _stub_exported_program({}, name_to_fqn=None)
assert _resolve_lifted_custom_obj(ep, _stub_node("missing")) is None


@pytest.mark.unit
def test_resolve_lifted_custom_obj_unwraps_fake_script_object():
class _Real:
pass

fake = FakeScriptObject(object(), "Engine", _Real())
ep = _stub_exported_program({"engine_fqn": fake}, {"obj_engine": "engine_fqn"})
resolved = _resolve_lifted_custom_obj(ep, _stub_node("obj_engine"))
assert not isinstance(resolved, FakeScriptObject)
assert isinstance(resolved, _Real)
Loading