fix: resolve lifted TRT engine custom-object by graph-signature FQN in ExecuTorch export#4349
Open
shoumikhin wants to merge 1 commit into
Open
fix: resolve lifted TRT engine custom-object by graph-signature FQN in ExecuTorch export#4349shoumikhin wants to merge 1 commit into
shoumikhin wants to merge 1 commit into
Conversation
Saving a partially-TRT-compiled program to ExecuTorch
(output_format="executorch") via the modern torch.export path (retrace=True)
aborts with:
RuntimeError: execute_engine node 'execute_engine': placeholder engine
'obj__run_on_acc_0_engine' not found in exp_program.constants
even though the engine is present. torch.export lifts the TRT engine
ScriptObject as a custom-object constant keyed by its graph-signature FQN
(InputSpec.target) and renames the placeholder node (an obj_ prefix), so the
existing constants[node.name] / constants[node.target] lookup misses. The
legacy exporter (retrace=False) only worked by accident: it kept the
placeholder name equal to the constants key.
Resolve the placeholder via the canonical
ExportGraphSignature.inputs_to_lifted_custom_objs mapping, falling back to the
direct lookup only for legacy programs that lack it, and unwrap a
FakeScriptObject to its real object. A shared helper in dynamo/_exporter.py is
used by both the save serializer (_compile.py) and the backend engine-info
extractor (executorch/backend.py), which carried the same latent lookup.
Adds CPU-only unit tests for the resolver (no GPU/executorch required).
This unblocks coalescing TensorRT + CUDA delegates into one .pte via the
modern exporter.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What's broken
Saving a Torch-TensorRT model with
output_format="executorch"fails when the model was compiled withrequire_full_compilation=False(so some ops stay outside TensorRT) and exported through the standardtorch.exportpath:The engine is actually present in
exp_program.constants—save()just looks it up by the wrong key.Why it happens
When
torch.exportlifts the TensorRT engine (aScriptObject), it stores it inexp_program.constantskeyed by its graph-signature FQN (InputSpec.target) but gives the placeholder node a mangled name (anobj_prefix). The current code looks it up by the node name/target:Those don't equal the FQN key, so the lookup returns
Noneand the export aborts. The legacy exporter (retrace=False) happened to keep the placeholder name equal to the constants key, so it worked by accident; the modern exporter is the one that breaks.The fix
Resolve the placeholder via the canonical
ExportGraphSignature.inputs_to_lifted_custom_objsmapping (placeholder name -> FQN), then readconstants[fqn]. Keep the direct name/target lookup only as a fallback for legacy programs that lack that mapping, and unwrap aFakeScriptObjectto its real object so it can be serialized. The same buggy lookup existed in two places (the save serializer in_compile.pyand the backend engine-info extractor inexecutorch/backend.py); both now share one helper indynamo/_exporter.py.Repro
Before: raises the error above. After: writes a
.ptecontaining both the TensorRT and CUDA delegates.Tests
Adds CPU-only unit tests for the resolver in
tests/py/dynamo/executorch/test_api.py(no GPU or ExecuTorch runtime required): signature-FQN resolution, legacy fallback, present-but-incomplete mapping ->None, missing ->None, andFakeScriptObjectunwrap.