torch2trt

[Question/Bug?] Master seems to default to input shapes of [1, ...]. Will this work for ops on the batch dim?

Open • chaoz-dev opened this issue 3 years ago • 4 comments

I just saw that v0.4.0 was released a few hours ago. Nice!

Looking through the changes, I noticed the following code snippet:

    # handle inputs as dataset of list of tensors
    if issubclass(inputs.__class__, Dataset):
        dataset = inputs
        if len(dataset) == 0:
            raise ValueError('Dataset must have at least one element to use for inference.')
        inputs = dataset[0]
    else:
        dataset = TensorBatchDataset(inputs)
        inputs = dataset[0]

This seems to default every input to shape [1, ...] again, as was done with the implicit batch dimension. I believe this was changed when explicit batch dimension support was added, prior to the release of v0.4.0?
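To make the concern concrete, here is a dependency-free sketch of the slicing behavior described above. It assumes, consistently with the size-1 `IndexError` in the traceback below, that `TensorBatchDataset.__getitem__(i)` returns `t[i:i+1]` for each input `t`. `BatchSliceDataset`, `make_nested`, and `shape_of` are hypothetical names, and nested Python lists stand in for tensors so the sketch runs without PyTorch.

```python
def make_nested(shape):
    """Build a rectangular nested list of zeros with the given shape (a tensor stand-in)."""
    if not shape:
        return 0.0
    return [make_nested(shape[1:]) for _ in range(shape[0])]


def shape_of(nested):
    """Return the shape of a rectangular nested list, like tensor.shape."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


class BatchSliceDataset:
    """Hypothetical stand-in for TensorBatchDataset (assumed slicing behavior).

    Item i is a length-1 slice of every input along dim 0, so dim 0 is
    treated as a batch dimension whether or not it really is one.
    """

    def __init__(self, inputs):
        self.inputs = inputs

    def __len__(self):
        return len(self.inputs[0])

    def __getitem__(self, idx):
        return [x[idx:idx + 1] for x in self.inputs]


tensor = make_nested((7, 3, 4))
dataset = BatchSliceDataset([tensor])
inputs = dataset[0]              # mirrors `inputs = dataset[0]` in the snippet

print(len(dataset))              # 7
print(shape_of(tensor))          # (7, 3, 4)
print(shape_of(inputs[0]))       # (1, 3, 4): the sample passed to forward()
```

Under this assumption, any op in `forward()` that indexes dim 0 past position 0 fails while the sample is being traced.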

Will this work for inputs whose first dim has size > 1, when we want to operate on that dim? For example:

    import torch

    t = torch.rand(7, 2, 3, 4)
    t[3]  # indexes dim 0; the result has shape (2, 3, 4)
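The failure reduces to shape arithmetic. The sketch below uses a hypothetical helper, `index_dim0`, that returns the shape of `x[i]` and imitates PyTorch's bounds check, with the error text copied from the traceback below. It is an illustration, not torch2trt or PyTorch code.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def index_dim0(shape: Shape, i: int) -> Shape:
    """Shape of x[i] for a tensor x of the given shape (integer index on dim 0).

    Mimics PyTorch's bounds check; negative indices count from the end.
    """
    size = shape[0]
    if not -size <= i < size:
        raise IndexError(f"index {i} is out of bounds for dimension 0 with size {size}")
    return shape[1:]


# Full input: t = torch.rand(7, 2, 3, 4); t[3] drops dim 0
print(index_dim0((7, 2, 3, 4), 3))    # (2, 3, 4)

# The [1, ...] sample a batch-slicing dataset would hand to forward()
try:
    index_dim0((1, 2, 3, 4), 3)
except IndexError as err:
    print(err)                         # index 3 is out of bounds for dimension 0 with size 1
```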

Running the following script (getitem-element.py) suggests that this will be a problem:

  import logging
  import tensorrt
  import torch
  import torch2trt
  
  
  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)
  
  DEVICE = 'cuda:0'
  TENSOR = torch.rand(7, 3, 4).to(DEVICE)
  
  
  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()
  
      def forward(self, tensor):
          return tensor[3]
  
  
  if __name__ == "__main__":
      model = Model().eval().to(DEVICE)
      out = model(TENSOR)
      print(f'Expected model output: {out}')
  
      model_trt = torch2trt.torch2trt(
          model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO
      )
      out = model_trt(TENSOR)
      print(f'TRT model output: {out}')

Output:

  (torch2trt-master-py3.7.7) ~/workspace/pytorch-scratch/torch2trt $ python getitem-element.py
  Expected model output: tensor([[0.2081, 0.9298, 0.7231, 0.7423],
          [0.5263, 0.2437, 0.5846, 0.0332],
          [0.1387, 0.2422, 0.8155, 0.7932]], device='cuda:0')
  Traceback (most recent call last):
    File "getitem-element.py", line 28, in <module>
      model, [TENSOR], max_batch_size=TENSOR.shape[0], log_level=tensorrt.Logger.INFO
    File "/home/chaoz/.anaconda3/envs/torch2trt-master-py3.7.7/lib/python3.7/site-packages/torch2trt-0.4.0-py3.7.egg/torch2trt/torch2trt.py", line 662, in torch2trt
      outputs = module(*inputs)
    File "/home/chaoz/.anaconda3/envs/torch2trt-master-py3.7.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
      return forward_call(*input, **kwargs)
    File "getitem-element.py", line 19, in forward
      return tensor[3]
  IndexError: index 3 is out of bounds for dimension 0 with size 1

chaoz-dev, Jul 23 '22 03:07

This will block resolving #768.

chaoz-dev, Jul 23 '22 03:07

Hi @chaoz-dev,

Thanks for pointing this out. I think this is a limitation of TensorBatchDataset, which assumes the first dimension is a batch dimension. Perhaps a different default behavior would work better.

However, I'm curious whether the following works for you:

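    from torch2trt import torch2trt
    from torch2trt.dataset import ListDataset

    # Each inserted sample is returned as-is, so TENSOR keeps its full (7, 3, 4) shape
    dataset = ListDataset()
    dataset.insert([TENSOR])  # one sample: the list of inputs passed to model.forward

    model_trt = torch2trt(model, dataset)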

ListDataset simply stores the inserted elements in a list and returns them as-is, without indexing into an assumed batch dimension. Maybe this would serve as a better default.
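For contrast, here is a dependency-free sketch of the ListDataset behavior just described. `ListLikeDataset` and `shape_of` are hypothetical names written from that one-sentence description, not the torch2trt class, and nested Python lists stand in for tensors.

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list, like tensor.shape."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


class ListLikeDataset:
    """Hypothetical mimic of ListDataset as described above:
    inserted elements are kept in a list and returned unchanged."""

    def __init__(self):
        self.items = []

    def insert(self, item):
        self.items.append(item)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


# Stand-in for TENSOR = torch.rand(7, 3, 4): a (7, 3, 4) nested list of zeros
tensor = [[[0.0] * 4 for _ in range(3)] for _ in range(7)]

dataset = ListLikeDataset()
dataset.insert([tensor])         # one sample = the list of forward() inputs

inputs = dataset[0]              # returned as-is, no batch slicing
print(shape_of(inputs[0]))       # (7, 3, 4)
print(shape_of(inputs[0][3]))    # (3, 4), so tensor[3] is valid
```

Because the sample is returned unchanged, the traced input keeps its real first dimension, which is why this approach should avoid the [1, ...] default.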

Best, John

jaybdub, Jul 25 '22 20:07

Just for bookkeeping: I think this issue is only in master. v0.4.0 was tagged before the dynamic shapes / dataset features were merged.

jaybdub, Jul 25 '22 20:07

Ah yes, I stand corrected. It looks like v0.4.0 should still work as previously described; the issue should be in master only. I'll give your suggestion above a try and see whether it resolves the issue.

chaoz-dev, Jul 29 '22 15:07