TensorRT
fix: replace add_identity with add_cast for type cast
Description
This PR updates the `type_cast` helper function to ensure compatibility with TensorRT's strongly typed network mode.
`type_cast` used `add_identity()` followed by `set_output_type()` to perform the data type cast. However, in strongly typed mode, calling `set_output_type()` on the identity layer raises the error below:
```
ILayer::setOutputType: Error Code 3: API Usage Error (Parameter check failed, condition: !mNetwork->usingStronglyTyped(). INetworkLayer::setOutputType cannot be called for a strongly typed network.)
[graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1962] Error Code 2: Internal Error (Assertion !isInFlight(p.second.symbolicRep) failed. )
```
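For context, here is a minimal standalone sketch of the failing pattern against TensorRT's Python API (a hypothetical repro, not code from this PR; it assumes a TensorRT version with strongly typed network support):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# Build a strongly typed network, as use_explicit_typing=True does.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)
x = network.add_input("x", trt.float32, (1, 40))

# Old pattern: identity layer + set_output_type on its output.
layer = network.add_identity(x)
layer.set_output_type(0, trt.int32)  # raises the API Usage Error above
```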
`type_cast` is called by the `expand` function in `torch_tensorrt/dynamo/conversion/impl/slice/ops.py` with a dynamic dimension index:
https://github.com/pytorch/TensorRT/blob/f09be72451200d4f9a347a6141276e81ff2fd22b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py#L232-L237
The following code snippet reproduces the error:
```python
import torch
import torch_tensorrt
from torch.export._trace import _export
from torch_tensorrt.dynamo._compiler import CompilationSettings
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.lowering import get_decompositions


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = torch.nn.Linear(10, 10)

    def forward(self, input: torch.Tensor):
        # repeat decomposes to expand, which routes through type_cast
        # when the batch dimension is dynamic.
        return input.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)


model = Model().to("cuda")
x = torch.randn(1, 40).to("cuda")

ep = _export(model, (x,))
ep = ep.run_decompositions(get_decompositions(False))
gm = ep.module()

interpreter = TRTInterpreter(
    gm,
    [
        torch_tensorrt.Input(
            name="input",
            min_shape=(1, 40),
            opt_shape=(4, 40),
            max_shape=(8, 40),
            dtype=torch.float32,
        )
    ],
    # use_explicit_typing=True builds a strongly typed TensorRT network.
    compilation_settings=CompilationSettings(use_explicit_typing=True),
)
results = interpreter.run()
```
To address this, the function now uses `add_cast()` to insert an explicit cast layer that converts the input tensor to the desired `cast_type`.
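A minimal sketch of the reworked helper, assuming the `type_cast(network, target, name, input, cast_type)` shape implied above (the names and the layer-naming call are approximations, not the verbatim diff):

```python
import tensorrt as trt


def type_cast(
    network: trt.INetworkDefinition,
    target: str,
    name: str,
    input: trt.ITensor,
    cast_type: trt.DataType,
) -> trt.ITensor:
    """Cast `input` to `cast_type` via an explicit cast layer, which is
    valid in both weakly and strongly typed networks."""
    layer = network.add_cast(input, cast_type)
    layer.name = f"{name}_cast"  # stand-in for set_layer_name(...)
    return layer.get_output(0)
```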
If there was a specific reason for using `add_identity()`, please let me know; this change assumes the identity layer was not essential beyond type casting.
Type of change
- Bug fix (non-breaking change which fixes an issue)
Checklist:
- [X] My code follows the style guidelines of this project (You can use the linters)
- [X] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified
Thanks @junstar92 for the contribution. Instead of modifying the FX path, we should import these utilities from the dynamo path, since it is actively being developed. Can you modify this change so that `prepend_ones` is imported from `dynamo/conversion/converter_utils` instead?
```python
from torch_tensorrt.dynamo.conversion.converter_utils import (
    has_dynamic_shape,
    prepend_ones,
    set_layer_name,
)
```
LGTM apart from the changes mentioned above
@peri044 @zewenli98 Thanks for the suggestion. As you mentioned, I switched from the FX conversion utilities to the dynamo ones.
@zewenli98
> Besides, I noticed that you are using `from torch.export._trace import _export` instead of `from torch.export import export` in your repro. May I know the reason?
There's no special reason, it's just how I've been doing it.
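For reference, the public entry point should work equivalently for this repro (a hedged substitution, not verified as part of this PR):

```python
from torch.export import export

ep = export(model, (x,))
ep = ep.run_decompositions(get_decompositions(False))
```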
Also, @junstar92 please rebase with main. Some of the CI failures should be resolved