
[Bug/Feature Request] `torch.Tensor.__getitem__` doesn't support list or `torch.Tensor` arguments for indexing/slicing when TRT converting

Open · chaoz-dev opened this issue on Jun 30 '22 · 5 comments

torch.Tensor.__getitem__ supports list and torch.Tensor arguments for indexing, but these argument types currently do not convert correctly in torch2trt.

In other words, we'd like to perform the following ops, which currently fail or produce incorrect outputs after torch2trt conversion:

  tensor = torch.randn(2, 3)
  tensor[[1, 1, 1]]                # index using a list
  tensor[torch.tensor([1, 1, 1])]  # index using a tensor

This can be demonstrated with the following script:

getitem-tensors.py:

  import logging
  import tensorrt
  import torch
  import torch2trt


  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)

  DEVICE = torch.device("cuda:0")
  SHAPE = (2, 3)


  class ListIndexModel(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, a):
          return a[[1, 1, 1, 1]]


  class TensorIndexModel(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.index_tensor = torch.tensor([1, 1, 1, 1]).to(DEVICE)

      def forward(self, a):
          return a[self.index_tensor]


  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
      print(f'Input: {tensor}')

      list_model = ListIndexModel().eval().to(DEVICE)
      list_out = list_model(tensor)
      print(f'List index model: {list_out}')

      try:
          list_model_trt = torch2trt.torch2trt(
              list_model, [tensor], max_batch_size=SHAPE[0], log_level=tensorrt.Logger.INFO
          )
          list_out_trt = list_model_trt(tensor)
          print(f'TRT List index model: {list_out_trt}')
      except Exception:
          print('ERROR: List index model failed to convert.')

      tensor_model = TensorIndexModel().eval().to(DEVICE)
      tensor_out = tensor_model(tensor)
      print(f'Tensor index model: {tensor_out}')

      try:
          tensor_model_trt = torch2trt.torch2trt(
              tensor_model, [tensor], max_batch_size=SHAPE[0], log_level=tensorrt.Logger.INFO
          )
          tensor_out_trt = tensor_model_trt(tensor)
          print(f'TRT Tensor index model: {tensor_out_trt}')
      except Exception:
          print('ERROR: Tensor index model failed to convert.')

Outputs:

(torch2trt-master) ~/workspace/pytorch-scratch/torch2trt $ python getitem-tensors.py
Input: tensor([[-0.9247, -0.4253, -2.6438],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')
List index model: tensor([[ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')
[06/30/2022-02:52:43] [TRT] [I] [MemUsageChange] Init CUDA: CPU +348, GPU +0, now: CPU 2986, GPU 6116 (MiB)
[06/30/2022-02:52:44] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 3003 MiB, GPU 6116 MiB
[06/30/2022-02:52:44] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 3378 MiB, GPU 6240 MiB
[06/30/2022-02:52:44] [TRT] [W] Tensor DataType is determined at build time for tensors not marked as input or output.
[06/30/2022-02:52:44] [TRT] [E] 4: [graphShapeAnalyzer.cpp::analyzeShapes::1300] Error Code 4: Miscellaneous (IShuffleLayer :1:SHUFFLE:GPU: reshape changes volume. Reshaping [1,1] to [4,3].)
[06/30/2022-02:52:44] [TRT] [E] 4: [graphShapeAnalyzer.cpp::analyzeShapes::1300] Error Code 4: Miscellaneous (IShuffleLayer :1:SHUFFLE:GPU: reshape changes volume. Reshaping [1,1] to [4,3].)
[06/30/2022-02:52:44] [TRT] [E] 4: [graphShapeAnalyzer.cpp::analyzeShapes::1300] Error Code 4: Miscellaneous (IShuffleLayer :1:SHUFFLE:GPU: reshape changes volume. Reshaping [1,1] to [4,3].)
[06/30/2022-02:52:44] [TRT] [E] 4: :1:SHUFFLE:GPU: volume mismatch. Input dimensions [1,1] have volume 1 and output dimensions [4,3] have volume 12.
[06/30/2022-02:52:44] [TRT] [E] 4: [network.cpp::validate::2963] Error Code 4: Internal Error (Layer :1:SHUFFLE:GPU failed validation)
ERROR: List index model failed to convert.
Tensor index model: tensor([[ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')
[06/30/2022-02:52:44] [TRT] [I] The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.

[06/30/2022-02:52:44] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 3378, GPU 6240 (MiB)
[06/30/2022-02:52:44] [TRT] [W] Tensor DataType is determined at build time for tensors not marked as input or output.
[06/30/2022-02:52:45] [TRT] [W] TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 110.9.2
[06/30/2022-02:52:45] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +873, GPU +378, now: CPU 4251, GPU 6618 (MiB)
[06/30/2022-02:52:45] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +127, GPU +60, now: CPU 4378, GPU 6678 (MiB)
[06/30/2022-02:52:45] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[06/30/2022-02:52:45] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[06/30/2022-02:52:45] [TRT] [I] Total Host Persistent Memory: 0
[06/30/2022-02:52:45] [TRT] [I] Total Device Persistent Memory: 0
[06/30/2022-02:52:45] [TRT] [I] Total Scratch Memory: 0
[06/30/2022-02:52:45] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 4 MiB
[06/30/2022-02:52:45] [TRT] [I] Total Activation Memory: 0
[06/30/2022-02:52:45] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
[06/30/2022-02:52:45] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
TRT Tensor index model: tensor([[-0.9247, -0.4253, -2.6438],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')

Note the following from the output.

Input tensor:

Input: tensor([[-0.9247, -0.4253, -2.6438],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')

Expected output (regardless of whether we're indexing using lists or tensors):

tensor([[ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')

Output when indexing with a list (conversion fails with a shuffle volume-mismatch error):

[TRT] [E] 4: :1:SHUFFLE:GPU: volume mismatch. Input dimensions [1,1] have volume 1 and output dimensions [4,3] have volume 12.

Output when indexing with a tensor (the engine builds, but silently returns the input unchanged instead of the gathered rows):

tensor([[-0.9247, -0.4253, -2.6438],
        [ 0.1452, -0.1209, -0.5797]], device='cuda:0')

Digging into the PyTorch output some more, I believe PyTorch creates a new tensor for this __getitem__ variation. That would make sense, since the result can be larger or differently shaped than the tensor being indexed, so it cannot be a view of the original storage. This can be observed by using data_ptr() to check the memory address of the indexed tensor:

>>> grid
tensor([[0.3983, 0.6135, 0.4320, 0.4636],
        [0.8307, 0.1006, 0.6228, 0.5191],
        [0.5090, 0.7910, 0.4277, 0.7855]])
>>> grid.data_ptr()
93965479906752
>>> grid.storage()
 0.3983076214790344
 0.6135425567626953
 0.4319940209388733
 0.4636090397834778
 0.830711841583252
 0.10059648752212524
 0.6227966547012329
 0.5191217064857483
 0.5089586973190308
 0.7910491228103638
 0.42765843868255615
 0.7855058312416077
[torch.FloatStorage of size 12]
>>> grid[[1,]]
tensor([[0.8307, 0.1006, 0.6228, 0.5191]])
>>> grid[[1,]].data_ptr()
93965479919232
>>> grid[[1,]].storage()
 0.830711841583252
 0.10059648752212524
 0.6227966547012329
 0.5191217064857483
[torch.FloatStorage of size 4]

To replicate this behavior in TRT/torch2trt, we'll likely need to add the ability to handle list/tensor index arguments, and to produce a new tensor as the output.
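
As a very rough illustration of that direction (not the actual fix; see the PR linked below), a converter could lower tensor-style indexing onto TensorRT's IGatherLayer, which materializes a new tensor. This is a minimal sketch, assuming torch2trt's @tensorrt_converter hook and add_missing_trt_tensors helper (both exist in current torch2trt, though signatures may vary by version), an explicit-batch network, and only a single 1-D index on dim 0:

  # Illustrative sketch only; registered with enabled=False on purpose.
  import torch
  from torch2trt.torch2trt import tensorrt_converter, add_missing_trt_tensors

  @tensorrt_converter('torch.Tensor.__getitem__', enabled=False)
  def convert_getitem_gather(ctx):
      input = ctx.method_args[0]
      index = ctx.method_args[1]
      output = ctx.method_return

      # Normalize a Python list index to a tensor of indices.
      if isinstance(index, list):
          index = torch.tensor(index)

      if isinstance(index, torch.Tensor):
          input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
          # Bake the indices into the engine as a constant; a dynamic
          # index tensor would instead have to become a network input.
          indices_trt = ctx.network.add_constant(
              tuple(index.shape), index.to(torch.int32).cpu().numpy()
          ).get_output(0)
          # Gather along axis 0 produces a new tensor, mirroring the
          # copy-producing advanced indexing observed above. Assumes
          # explicit batch, so axis 0 here is dim 0 of the torch tensor.
          gather = ctx.network.add_gather(input_trt, indices_trt, 0)
          output._trt = gather.get_output(0)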

chaoz-dev commented Jun 30 '22

The above should be reproducible on the NGC container pytorch:22.06-py3.

chaoz-dev commented Jun 30 '22

Currently working on this implementation, but it's a bit tricky.

chaoz-dev commented Jul 21 '22

Adding to this, tuple arguments have the same behavior iff combined with `...` or `:`, i.e.:

  t = torch.rand(3, 4, 5)
  t[(1, 2)] == t[1][2]  # This should be handled by #768 and generally.
  t[(1, 2), ...] == t[(1, 2), :] == t[[1, 2]]  # Not currently supported; i.e. this issue.
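
For reference, these equivalences can be sanity-checked in plain PyTorch (no torch2trt involved):

  import torch

  t = torch.rand(3, 4, 5)
  # A bare tuple index is basic (chained) indexing: result shape (5,).
  assert torch.equal(t[(1, 2)], t[1][2])
  # A tuple combined with ... or : triggers advanced indexing: shape (2, 4, 5).
  assert torch.equal(t[(1, 2), ...], t[[1, 2]])
  assert torch.equal(t[(1, 2), :], t[[1, 2]])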

chaoz-dev commented Jul 23 '22

Alright, I think I have an algorithm that should work in the general case; I'll document it on the PR...

chaoz-dev commented Jul 30 '22

https://github.com/NVIDIA-AI-IOT/torch2trt/pull/783

chaoz-dev commented Aug 04 '22