🐛 [Bug] __len__ should return > 0 when using torch_tensorrt.compile with torch_tensorrt.Input
Bug Description
When using `torch_tensorrt.Input` with `torch_tensorrt.compile`, I get the following error:

```
File "/mnt/c/Coding/Testing/PyTorch/MultiClassImageClassification/src/compressmodel.py", line 48, in <module>
  trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], enabled_precisions = {torch.half, torch.float}, output_format="torchscript")
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 228, in compile
  trt_graph_module = dynamo_compile(
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 236, in compile
  trt_gm = compile_module(gm, inputs, settings)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 346, in compile_module
  trt_module = convert_module(
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 56, in convert_module
  interpreter_result = interpreter.run()
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 152, in run
  super().run()
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/fx/interpreter.py", line 138, in run
  self.env[node] = self.run_node(node)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 276, in run_node
  trt_node: torch.fx.Node = super().run_node(n)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/fx/interpreter.py", line 195, in run_node
  return getattr(self, n.op)(n.target, args, kwargs)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 362, in call_function
  return converter(self.ctx, target, args, kwargs, self._cur_node_name)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 103, in aten_ops_batch_norm_legit_no_training
  return impl.normalization.batch_norm(
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py", line 60, in batch_norm
  if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4:
ValueError: __len__() should return >= 0
```
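For context on the error message itself: CPython raises exactly this `ValueError` whenever `len()` is called on an object whose `__len__` returns a negative integer, which is presumably what happens here when the batch-norm converter calls `len(input.shape)` on a shape whose rank is reported as negative for the dynamic-shape input. A minimal sketch of the mechanism (the `FakeDims` class is a stand-in I made up, not a real TensorRT type):

```python
class FakeDims:
    """Stand-in for a shape object that reports a negative number of dims."""

    def __len__(self):
        return -1  # mimics an unknown/dynamic rank


shape = FakeDims()
try:
    # CPython rejects a negative __len__ result before returning it.
    len(shape)
except ValueError as e:
    print(e)  # __len__() should return >= 0
```

The same guard `len(input.shape) < 4` on line 60 of `ops.py` therefore cannot even evaluate when the shape reports a negative length.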
This does not occur when directly passing in sample images, but I cannot use that approach because then I cannot specify a variable batch size.
To Reproduce
Steps to reproduce the behavior:
1. Define an input:

```python
input = torch_tensorrt.Input(
    min_shape=(1, 3, config.model_image_size, config.model_image_size),
    opt_shape=(16, 3, config.model_image_size, config.model_image_size),
    max_shape=(16, 3, config.model_image_size, config.model_image_size),
    dtype=torch.half,
    name="x",
)
```
2. Use the input in `torch_tensorrt.compile` and observe the error:

```python
with autocast(enabled=True):
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.compile(
            model,
            ir="dynamo",
            inputs=[input],
            enabled_precisions={torch.half, torch.float},
            output_format="torchscript",
        )
```
Expected behavior
`torch_tensorrt.Input` with a dynamic batch dimension would compile without error, just as compiling with concrete sample tensors does.
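One possible direction for a fix (purely a sketch on my part, not the library's actual code): the converter could tolerate a shape whose `__len__` is negative instead of letting the `ValueError` propagate. The `safe_rank` helper and `DynamicDims` class below are hypothetical names for illustration:

```python
def safe_rank(shape):
    """Return len(shape), or -1 if the shape reports an unknown/dynamic rank."""
    try:
        return len(shape)
    except ValueError:
        # CPython raises ValueError when __len__ returns a negative int.
        return -1


class DynamicDims:
    """Stand-in for a dims object with a dynamic (negative) reported rank."""

    def __len__(self):
        return -1


print(safe_rank((16, 3, 224, 224)))  # 4
print(safe_rank(DynamicDims()))      # -1
```

A guard like `safe_rank(input.shape) < 4` would then need an explicit decision for the `-1` case rather than crashing.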
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.2.0
- PyTorch Version (e.g. 1.0): 2.2.1
- CPU Architecture: Intel
- OS (e.g., Linux): WSL
- How you installed PyTorch (conda, pip, libtorch, source): conda
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: directly from pip
- Python version: 3.11
- CUDA version: 12.1
- GPU models and configuration: RTX 3090
- Any other relevant information:
@Skier23 Could you provide your model?