Skip to content

fix: resolve lifted TRT engine custom-object by graph-signature FQN in ExecuTorch export#4349

Open
shoumikhin wants to merge 1 commit into
pytorch:mainfrom
shoumikhin:fix/executorch-lifted-custom-obj
Open

fix: resolve lifted TRT engine custom-object by graph-signature FQN in ExecuTorch export#4349
shoumikhin wants to merge 1 commit into
pytorch:mainfrom
shoumikhin:fix/executorch-lifted-custom-obj

Conversation

@shoumikhin

Copy link
Copy Markdown
Contributor

What's broken

Saving a Torch-TensorRT model with output_format="executorch" fails when the model was compiled with require_full_compilation=False (so some ops stay outside TensorRT) and exported through the standard torch.export path:

RuntimeError: execute_engine node 'execute_engine': placeholder engine
'obj__run_on_acc_0_engine' not found in exp_program.constants

The engine is actually present in exp_program.constantssave() just looks it up by the wrong key.

Why it happens

When torch.export lifts the TensorRT engine (a ScriptObject), it stores it in exp_program.constants keyed by its graph-signature FQN (InputSpec.target) but gives the placeholder node a mangled name (an obj_ prefix). The current code looks it up by the node name/target:

engine_obj = constants.get(engine_node.name) or constants.get(engine_node.target)

Those don't equal the FQN key, so the lookup returns None and the export aborts. The legacy exporter (retrace=False) happened to keep the placeholder name equal to the constants key, so it worked by accident; the modern exporter is the one that breaks.

The fix

Resolve the placeholder via the canonical ExportGraphSignature.inputs_to_lifted_custom_objs mapping (placeholder name -> FQN), then read constants[fqn]. Keep the direct name/target lookup only as a fallback for legacy programs that lack that mapping, and unwrap a FakeScriptObject to its real object so it can be serialized. The same buggy lookup existed in two places (the save serializer in _compile.py and the backend engine-info extractor in executorch/backend.py); both now share one helper in dynamo/_exporter.py.

Repro

import torch
import torch_tensorrt

class M(torch.nn.Module):
    def forward(self, x):
        a = (x + 1.0) * 2.0       # TensorRT
        b = torch.cos(a)          # forced off TensorRT
        return (b * 3.0) + 4.0    # TensorRT

ep = torch.export.export(M().eval().cuda(), (torch.randn(2, 3, 4, 4).cuda(),))
gm = torch_tensorrt.dynamo.compile(
    ep, min_block_size=1, torch_executed_ops={torch.ops.aten.cos.default}
)
torch_tensorrt.save(
    gm,
    "m.pte",
    output_format="executorch",
    arg_inputs=(torch.randn(2, 3, 4, 4).cuda(),),
)

Before: raises the error above. After: writes a .pte containing both the TensorRT and CUDA delegates.

Tests

Adds CPU-only unit tests for the resolver in tests/py/dynamo/executorch/test_api.py (no GPU or ExecuTorch runtime required): signature-FQN resolution, legacy fallback, present-but-incomplete mapping -> None, missing -> None, and FakeScriptObject unwrap.

Saving a partially-TRT-compiled program to ExecuTorch
(output_format="executorch") via the modern torch.export path (retrace=True)
aborts with:

    RuntimeError: execute_engine node 'execute_engine': placeholder engine
    'obj__run_on_acc_0_engine' not found in exp_program.constants

even though the engine is present. torch.export lifts the TRT engine
ScriptObject as a custom-object constant keyed by its graph-signature FQN
(InputSpec.target) and renames the placeholder node (an obj_ prefix), so the
existing constants[node.name] / constants[node.target] lookup misses. The
legacy exporter (retrace=False) only worked by accident: it kept the
placeholder name equal to the constants key.

Resolve the placeholder via the canonical
ExportGraphSignature.inputs_to_lifted_custom_objs mapping, falling back to the
direct lookup only for legacy programs that lack it, and unwrap a
FakeScriptObject to its real object. A shared helper in dynamo/_exporter.py is
used by both the save serializer (_compile.py) and the backend engine-info
extractor (executorch/backend.py), which carried the same latent lookup.

Adds CPU-only unit tests for the resolver (no GPU/executorch required).

This unblocks coalescing TensorRT + CUDA delegates into one .pte via the
modern exporter.
@meta-cla meta-cla Bot added the cla signed label Jun 18, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jun 18, 2026
@github-actions github-actions Bot requested a review from lanluo-nvidia June 18, 2026 01:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant