TensorRT
🐛 [Bug] torch.outer doesn't compile correctly with dynamo backend
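Bug Description
Compiling a model that uses torch.outer with ir='dynamo' fails during TensorRT conversion. In the debug log, the aten.mul.Tensor node multiplies the reshaped (10, 1) input by the constant arange(5) (shape (5,)). Its output keeps shape (10, 1) instead of the broadcast shape (10, 5). The following aten.select.int calls therefore index into a tensor that is too small, and conversion fails: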
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[10], dtype=DataType.FLOAT]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/reshape_default (kind: aten.reshape.default, args: ('arg0_1 <tensorrt.ITensor [shape=(10,), dtype=DataType.FLOAT]>', [10, 1]))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/mul (kind: aten.mul.Tensor, args: ('[SHUFFLE]-[aten_ops.reshape.default]-[__/reshape_default]_output <tensorrt.ITensor [shape=(10, 1), dtype=DataType.FLOAT]>', '<torch.Tensor as np.ndarray [shape=(5,), dtype=int64]>'))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/select (kind: aten.select.int, args: ('[ELEMENTWISE]-[aten_ops.mul.Tensor]-[__/mul]_output_mul.Tensor <tensorrt.ITensor [shape=(10, 1), dtype=DataType.FLOAT]>', 0, 9))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node __/select_1 (kind: aten.select.int, args: ('(Unnamed Layer* 7) [Gather]_output <tensorrt.ITensor [shape=(1,), dtype=DataType.FLOAT]>', 0, 4))
RuntimeError: cannot have index greater than the dimension length! 4
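For reference, aten.mul.Tensor follows standard broadcasting rules, so multiplying a (10, 1) tensor by a (5,) tensor should yield (10, 5). The sketch below uses two hypothetical pure-Python helpers (`broadcast_shape` and `select`, which are not part of PyTorch or Torch-TensorRT). It computes the broadcast shape and replays the two select calls from the log. This shows that index 4 is only out of range when the mul output keeps shape (10, 1).

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes (NumPy/PyTorch rules)."""
    result = []
    # Walk dimensions right-to-left, treating missing leading dims as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dim stretches to match the other operand's size.
        result.append(da if db == 1 else db)
    return tuple(reversed(result))


def select(shape, dim, index):
    """Shape after aten.select.int: drop `dim` after checking the index bound."""
    if not 0 <= index < shape[dim]:
        raise IndexError(f"index {index} out of range for dim of size {shape[dim]}")
    return shape[:dim] + shape[dim + 1:]


expected = broadcast_shape((10, 1), (5,))
print(expected)  # (10, 5)

# Replaying select(0, 9) then select(0, 4) on the broadcast shape succeeds.
print(select(select(expected, 0, 9), 0, 4))  # ()

# Replaying them on the (10, 1) shape reported in the log fails at index 4.
try:
    select(select((10, 1), 0, 9), 0, 4)
except IndexError as exc:
    print(exc)
```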
To Reproduce
Steps to reproduce the behavior:
import torch
import torch_tensorrt
import logging
import torch_tensorrt.dynamo.conversion._TRTInterpreter

# Enable debug logging for the dynamo TensorRT interpreter
torch_tensorrt.dynamo.conversion._TRTInterpreter._LOGGER.setLevel(logging.DEBUG)


class MyModel(torch.nn.Module):
    def forward(self, input):
        a = torch.outer(input, torch.arange(5).cuda())
        # Workaround (compiles correctly):
        # a = input.view(-1, 1) @ torch.arange(5, dtype=torch.float32).cuda().view(1, -1)
        return a[a.shape[0]-1, 4]


model = MyModel().eval().cuda()
sample_input = torch.arange(10, dtype=torch.float32).cuda()
inputs = (sample_input,)
print(model(*inputs))  # eager execution succeeds
torch_tensorrt.compile(model, ir='dynamo', inputs=inputs, min_block_size=2)
Expected behavior
The model should compile successfully. Running it eagerly in PyTorch with the same inputs works.
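As a sanity check on the expected result, torch.outer(x, y) has entries x[i] * y[j]. With input = arange(10) and y = arange(5), the model therefore returns a[9, 4] = 9 * 4 = 36. The snippet below is a hypothetical pure-Python reference that reproduces this value without CUDA or TensorRT. Here `outer` and `reference_forward` are illustrative stand-ins, not the PyTorch function or the model itself. The eager model, and a correctly compiled one, should return a tensor with this same value.

```python
def outer(x, y):
    """Pure-Python equivalent of torch.outer for 1-D sequences: out[i][j] = x[i] * y[j]."""
    return [[xi * yj for yj in y] for xi in x]


def reference_forward(values):
    """Mirror MyModel.forward: outer product with range(5), then read row -1, column 4."""
    a = outer(values, range(5))
    return a[len(a) - 1][4]


sample_input = [float(v) for v in range(10)]
print(reference_forward(sample_input))  # 36.0
```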
Environment
Python 3.11.2
torch 2.3.1+cu121
torch_tensorrt 2.3.0+cu121