torch_dim_to_trt_axes does not handle dim=-1 correctly

Open Thrsu opened this issue 1 year ago • 0 comments

Description:

According to the official PyTorch documentation, the default value of dim for torch.nn.functional.gumbel_softmax is -1, which selects the last dimension. However, torch_dim_to_trt_axes does not handle negative dim values: it uses dim directly as a shift count when building the TensorRT axes bitmask, so dim=-1 raises "ValueError: negative shift count" during conversion instead of being mapped to the last axis.

Reproduce:

Here is a minimal script to reproduce the issue:

import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([5, 5], dtype=torch.float32).cuda()
para_1 = 2.0
para_2 = True
class gumbel_softmax(Module):
    def forward(self, *args):
        return torch.nn.functional.gumbel_softmax(args[0], para_1,para_2,)
model = gumbel_softmax().float().eval().cuda()
model_trt = torch2trt(model, [para_0])

The traceback information is as below:

Traceback (most recent call last):
  ...
   model_trt = torch2trt(model, [para_0])
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 778, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/share/OPERA/torch/reproduce/test_3138.py", line 11, in forward
    return torch.nn.functional.gumbel_softmax(args[0], para_1,para_2,)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 300, in wrapper
    outputs = method(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/functional.py", line 1915, in gumbel_softmax
    index = y_soft.max(dim, keepdim=True)[1]
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 309, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 36, in convert_max
    __convert_max_reduce(ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/max.py", line 26, in __convert_max_reduce
    layer = ctx.network.add_reduce(input_trt,  trt.ReduceOperation.MAX, torch_dim_to_trt_axes(dim), keepdim)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 116, in torch_dim_to_trt_axes
    axes |= 1 << d 
ValueError: negative shift count
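
The last frame shows the cause: `1 << d` is evaluated with d = -1, and Python rejects negative shift counts. One possible fix is to normalize negative dims against the tensor rank before shifting. The sketch below is only an illustration of that idea, not the library's code: `dim_to_axes_bitmask` and its `ndim` parameter are hypothetical names, and it mirrors the `axes |= 1 << d` line from the traceback without any batch-dimension offset.

```python
def dim_to_axes_bitmask(dim, ndim):
    """Hypothetical helper: turn a dim (or tuple of dims) into an axes bitmask.

    `ndim` is the rank of the tensor being reduced; it is an assumed extra
    argument, needed to resolve negative dims such as -1.
    """
    if not isinstance(dim, (tuple, list)):
        dim = (dim,)

    axes = 0
    for d in dim:
        # Reject dims outside [-ndim, ndim), matching PyTorch's indexing rules.
        if not -ndim <= d < ndim:
            raise IndexError(f"dim {d} is out of range for a tensor of rank {ndim}")
        # Map negative dims to their non-negative equivalent, e.g. -1 -> ndim - 1.
        axes |= 1 << (d % ndim)
    return axes


# For the 2-D input in the script above, dim=-1 and dim=1 select the same axis.
assert dim_to_axes_bitmask(-1, 2) == dim_to_axes_bitmask(1, 2)
```

In the converter, the rank could presumably be taken from the input tensor (for example `input.dim()`) and passed alongside dim; that wiring is an assumption and is not shown in the traceback.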

Environment:

  • torch: 2.1.1
  • torch2trt: 0.4.0
  • tensorrt: 8.6.1

Thrsu, Nov 27 '23 18:11