[Bug] `torch.Tensor.__getitem__` fails when indexing/slicing with a single element, with error: `TypeError: 'int' object is not iterable`

Open chaoz-dev opened this issue 3 years ago • 3 comments

Problem

The torch2trt converter for `torch.Tensor.__getitem__` fails when a tensor is indexed or sliced with a single element, e.g.

tensor = torch.rand(2, 3)
tensor[0]

The issue is in torch2trt/converters/getitem.py, where `slices` is assumed to be iterable. That assumption does not hold in the case above, where `slices` is a single element (specifically, the int passed as the index).
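To illustrate why the converter sees a bare int, here is a minimal pure-Python sketch. `KeyRecorder` is a hypothetical stand-in for a tensor (it is not part of torch or torch2trt) that returns whatever key Python passes to `__getitem__`. It assumes only standard Python indexing semantics.

```python
class KeyRecorder:
    """Hypothetical stand-in for a tensor: returns the key it was indexed with."""

    def __getitem__(self, key):
        return key


rec = KeyRecorder()

# A single index arrives as the bare object, not as a 1-tuple.
print(type(rec[0]).__name__)     # int
print(type(rec[0:1]).__name__)   # slice
print(type(rec[...]).__name__)   # ellipsis

# Only comma-separated indices are packed into a tuple.
print(type(rec[0, 1]).__name__)  # tuple

# Iterating over the single-index key reproduces the reported error.
try:
    for s in rec[0]:
        pass
except TypeError as err:
    print(err)                   # 'int' object is not iterable
```

So any code that loops over the key must first handle the non-tuple case.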

Script

Running the following script, getitem-element.py, using NGC 22.06-py3:

  import logging    
  import tensorrt    
  import torch    
  import torch2trt    
      
      
  logging.basicConfig(level=logging.INFO)    
  torch.manual_seed(0)    
      
  DEVICE = 'cuda:0'    
  TENSOR = torch.rand(2, 3).to(DEVICE)    
      
      
  class Model(torch.nn.Module):    
      def __init__(self):    
          super().__init__()    
      
      def forward(self, tensor):    
          return tensor[0]    
      
      
  if __name__ == "__main__":    
      model = Model().eval().to(DEVICE)    
      out = model(TENSOR)    
      print(f'Expected model output: {out}')    
      
      model_trt = torch2trt.torch2trt(    
          model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO    
      )    
      out = model_trt(TENSOR)    
      print(f'TRT model output: {out}')    

produces the following output:

root@8f319e91dd9a:/opt# python /scripts/getitem-element.py
Expected model output: tensor([0.4963, 0.7682, 0.0885], device='cuda:0')
[07/22/2022-23:49:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU +464, GPU +0, now: CPU 1299, GPU 817 (MiB)
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 1299 MiB, GPU 817 MiB
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 1453 MiB, GPU 859 MiB
Traceback (most recent call last):
  File "/scripts/getitem-element.py", line 27, in <module>
    model_trt = torch2trt.torch2trt(
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 736, in torch2trt
    outputs = module(*inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/scripts/getitem-element.py", line 19, in forward
    return tensor[0]
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 307, in wrapper
    converter["converter"](ctx)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 34, in convert_tensor_getitem
    num_ellipsis = len(input.shape) - num_slice_types(slices)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 18, in num_slice_types
    for s in slices:
TypeError: 'int' object is not iterable

chaoz-dev avatar Jul 22 '22 23:07 chaoz-dev

This is likely the same issue presented in #247.

chaoz-dev avatar Jul 22 '22 23:07 chaoz-dev

I'll post a solution shortly.

I believe the following should work: convert `slices` into a tuple if it is not already one, then consume that tuple as the iterable input, as before.

I believe this also matches PyTorch's indexing behavior; specifically,

tensor[(0,)] == tensor[0]
tensor[(0, 1)] == tensor[0][1]
tensor[(0, 1), 0] == tensor[[0, 1], 0] # This case isn't handled yet; see #755 
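The normalization proposed above could be sketched roughly as follows. Both function names are hypothetical and exist only for illustration. `count_int_and_slice` is loosely modeled on the `for s in slices:` loop visible in the traceback and is an assumption about what that loop counts, not the library's actual code.

```python
def as_index_tuple(slices):
    """Hypothetical helper: wrap a bare index in a 1-tuple; return tuples unchanged."""
    if isinstance(slices, tuple):
        return slices
    return (slices,)


def count_int_and_slice(slices):
    """Illustrative counter (assumed behavior): count int and slice entries."""
    return sum(1 for s in as_index_tuple(slices) if isinstance(s, (int, slice)))


print(as_index_tuple(0))        # (0,)
print(as_index_tuple((0, 1)))   # (0, 1)
print(count_int_and_slice(0))   # 1
print(count_int_and_slice((0, slice(None), Ellipsis)))  # 2
```

Checking for `tuple` specifically, rather than for any iterable, means a list index such as `[0, 1]` is wrapped as a single element instead of being unpacked, which is consistent with the note above that list indices are a separate, unhandled case.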

chaoz-dev avatar Jul 23 '22 00:07 chaoz-dev

~~Looks like there might be issues with the : and ... arguments as well, when used on the first dim of the tensor.~~

~~Caused by #769~~

chaoz-dev avatar Jul 23 '22 02:07 chaoz-dev