torch2trt
[Bug] `torch.Tensor.__getitem__` fails when indexing/slicing with a single element, with error: `TypeError: 'int' object is not iterable`
Problem
The torch2trt conversion for torch.Tensor.__getitem__ fails when indexing/slicing with a single element, e.g.
tensor = torch.rand(2, 3)
tensor[0]
The issue is in torch2trt/converters/getitem.py, where slices is assumed to be iterable. This is not necessarily true in the aforementioned use case, where slices will actually be a single element (specifically the int given as the indexing argument).
Script
Running the following script, getitem-element.py, in the NGC 22.06-py3 container:
import logging
import tensorrt
import torch
import torch2trt
logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)
DEVICE = 'cuda:0'
TENSOR = torch.rand(2, 3).to(DEVICE)
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        return tensor[0]

if __name__ == "__main__":
    model = Model().eval().to(DEVICE)
    out = model(TENSOR)
    print(f'Expected model output: {out}')
    model_trt = torch2trt.torch2trt(
        model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO
    )
    out = model_trt(TENSOR)
    print(f'TRT model output: {out}')
produces the following output:
root@8f319e91dd9a:/opt# python /scripts/getitem-element.py
Expected model output: tensor([0.4963, 0.7682, 0.0885], device='cuda:0')
[07/22/2022-23:49:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU +464, GPU +0, now: CPU 1299, GPU 817 (MiB)
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 1299 MiB, GPU 817 MiB
[07/22/2022-23:49:16] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 1453 MiB, GPU 859 MiB
Traceback (most recent call last):
File "/scripts/getitem-element.py", line 27, in <module>
model_trt = torch2trt.torch2trt(
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 736, in torch2trt
outputs = module(*inputs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/scripts/getitem-element.py", line 19, in forward
return tensor[0]
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 307, in wrapper
converter["converter"](ctx)
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 34, in convert_tensor_getitem
num_ellipsis = len(input.shape) - num_slice_types(slices)
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/getitem.py", line 18, in num_slice_types
for s in slices:
TypeError: 'int' object is not iterable
This is likely the same issue presented in #247.
I'll post a solution shortly.
I believe the following should work:
We should be able to convert slices into a tuple if it is not already one, and then iterate over that tuple exactly as before; a sketch follows.
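A minimal, standalone sketch of that normalization (normalize_slices is an illustrative name, and num_slice_types here is a simplified stand-in for the helper in torch2trt/converters/getitem.py, not the actual library code):

import torch

def normalize_slices(slices):
    # tensor[0] passes a bare int, tensor[0:1] a bare slice, tensor[0, 1] a tuple;
    # wrap the non-tuple cases so downstream loops can always iterate.
    if not isinstance(slices, tuple):
        slices = (slices,)
    return slices

def num_slice_types(slices):
    # Simplified stand-in for torch2trt's helper: count entries that consume an input dim.
    return sum(1 for s in slices if isinstance(s, (slice, int)))

tensor = torch.rand(2, 3)
assert tensor[normalize_slices(0)].equal(tensor[0])              # the 1-tuple behaves like the bare index
assert tensor[normalize_slices(slice(0, 1))].equal(tensor[0:1])  # same for a bare slice
print(num_slice_types(normalize_slices(0)))  # 1, instead of TypeError: 'int' object is not iterable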
I believe this matches PyTorch's indexing behavior as well; specifically,
tensor[(0,)] == tensor[0]
tensor[(0, 1)] == tensor[0][1]
tensor[(0, 1), 0] == tensor[[0, 1], 0] # This case isn't handled yet; see #755
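For reference, the first two equivalences are easy to confirm in plain PyTorch (eager mode, no torch2trt involved):

import torch
tensor = torch.rand(2, 3)
assert tensor[(0,)].equal(tensor[0])       # 1-tuple index == bare int index
assert tensor[(0, 1)].equal(tensor[0][1])  # tuple index == chained indexing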
~~Looks like there might be issues with the : and ... arguments as well, when used on the first dim of the tensor.~~
~~Caused by #769~~