torch2trt
torch_dim_to_trt_axes does not handle dim=-1 correctly
Description:
According to the official PyTorch documentation, the default value of `dim` for `torch.nn.functional.gumbel_softmax` is -1, which refers to the last dimension. However, the `torch_dim_to_trt_axes` function does not handle negative `dim` values such as -1. It shifts by the raw `dim` value (`axes |= 1 << d`), so converting `dim=-1` to a TensorRT axes bitmask fails with `ValueError: negative shift count` instead of selecting the last axis.
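One way the conversion could cope with negative dims is to wrap them into the range `[0, ndim)` before shifting, the same way PyTorch interprets them. The sketch below assumes, as the `axes |= 1 << d` line in the traceback suggests, that bit `d` of the mask marks dimension `d` directly, and that the caller knows the input rank `ndim`. `dim_to_axes_bitmask` is a hypothetical helper for illustration, not torch2trt's actual API or fix.

```python
def dim_to_axes_bitmask(dim, ndim):
    """Hypothetical helper: convert a PyTorch dim (int or tuple/list of ints,
    possibly negative) into a reduce-axes bitmask for an ndim-D tensor.
    Assumes bit d of the mask corresponds directly to dimension d."""
    dims = dim if isinstance(dim, (tuple, list)) else (dim,)
    axes = 0
    for d in dims:
        # Wrap negative dims the way PyTorch does, e.g. -1 -> ndim - 1.
        axis = d + ndim if d < 0 else d
        if not 0 <= axis < ndim:
            raise IndexError(f"dim {d} is out of range for a {ndim}-D tensor")
        axes |= 1 << axis  # set the bit for this axis
    return axes


# For a 2-D input like para_0, dim=-1 selects the same axis as dim=1.
print(dim_to_axes_bitmask(-1, 2))       # 2 (0b10)
print(dim_to_axes_bitmask((0, -1), 3))  # 5 (0b101)
```

In a converter, `ndim` would have to come from the input tensor's shape; the range check keeps an out-of-range dim from silently producing a mask for a nonexistent axis.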
Reproduce:
Here is a minimal script to reproduce the issue:
```python
import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([5, 5], dtype=torch.float32).cuda()  # logits
para_1 = 2.0   # tau
para_2 = True  # hard=True makes gumbel_softmax call y_soft.max(dim, keepdim=True) with the default dim=-1

class gumbel_softmax(Module):
    def forward(self, *args):
        return torch.nn.functional.gumbel_softmax(args[0], para_1, para_2)

model = gumbel_softmax().float().eval().cuda()
model_trt = torch2trt(model, [para_0])  # raises ValueError: negative shift count
```
The resulting traceback:
```
Traceback (most recent call last):
  ...
    model_trt = torch2trt(model, [para_0])
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 778, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/share/OPERA/torch/reproduce/test_3138.py", line 11, in forward
    return torch.nn.functional.gumbel_softmax(args[0], para_1,para_2,)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 300, in wrapper
    outputs = method(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/functional.py", line 1915, in gumbel_softmax
    index = y_soft.max(dim, keepdim=True)[1]
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 309, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 36, in convert_max
    __convert_max_reduce(ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 26, in __convert_max_reduce
    layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX, torch_dim_to_trt_axes(dim), keepdim)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 116, in torch_dim_to_trt_axes
    axes |= 1 << d
ValueError: negative shift count
```
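The last frame can be reproduced in plain Python without torch or TensorRT: Python's `<<` operator rejects a negative shift count, which is exactly what `1 << d` hits when `d` is -1. This is a minimal illustration of the root cause; `shift_error_message` is a hypothetical helper written only for this demonstration.

```python
def shift_error_message(d):
    """Return the error message raised by `1 << d`, or None if the shift succeeds."""
    try:
        1 << d
    except ValueError as exc:
        return str(exc)
    return None


print(shift_error_message(-1))      # negative shift count (same error as the traceback)
print(shift_error_message(-1 + 2))  # None: wrapping -1 for a 2-D tensor gives a valid shift
```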
Environment:
- torch: 2.1.1
- torch2trt: 0.4.0
- tensorrt: 8.6.1