✨[Feature] Support `list` and `namedtuple` input types to `forward` function

Open chaoz-dev opened this issue 3 years ago • 8 comments

Is your feature request related to a problem? Please describe.

Currently, the forward function only supports tensor inputs when compiling. However, we sometimes want to pass many tensors (say, more than 10) into forward at once, which makes for a very long call in which every tensor has to be listed individually. It would be helpful to pass in a single container holding these tensors instead, which gives a much cleaner API call.

For this specific request, I focus on the list and namedtuple input types first, since these should cover most basic use cases (and should functionally satisfy named tensor key-value pair inputs).

Describe the solution you'd like

Instead of supporting only the following, where we need to supply torch.Tensors into forward:

  import torch
  import torch_tensorrt

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, a, b):
          return a - b

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

      model = Model().eval().to(DEVICE)
      out = model(tensor, tensor)

      # One torch_tensorrt.Input per tensor argument of forward.
      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      # Run the compiled module (not the original) for the comparison.
      out_trt = model_trt(tensor, tensor)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

Also support passing a namedtuple or list into forward:

  from collections import namedtuple

  import torch
  import torch_tensorrt

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  Input = namedtuple('Input', ['t1', 't2'])

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, input_: Input):
          return input_.t1 - input_.t2

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
      input_ = Input(tensor, tensor)

      model = Model().eval().to(DEVICE)
      out = model(input_)

      # Requested behavior: one torch_tensorrt.Input per tensor in the
      # container, listed in the same order as the container's fields.
      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      # Run the compiled module (not the original) for the comparison.
      out_trt = model_trt(input_)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

Describe alternatives you've considered

Currently the only alternative is to supply tensors directly into the forward function; supplying namedtuples will cause the compilation to segfault, and supplying lists will cause the compilation to fail to recognize the input altogether.

Additional context

  • For simplicity, the input containers should contain ONLY tensors (implying that we disallow nested containers). Containers with mixed input types are ignored.
  • Furthermore, there must be a bijection between the tensors in the container and the sizes provided to the compile call; i.e. there must be exactly one Input size for each tensor in the container, and both are taken in the same order.
  • We can mix tensors and containers in the forward call (e.g. forward(x: torch.Tensor, y: List[torch.Tensor], z: namedtuple[torch.Tensor])). Any other input types are treated as they are today.
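
The rules above can be sketched in plain Python. This is a minimal illustration, not Torch-TensorRT code: `FakeTensor` is a stand-in for `torch.Tensor` so the snippet runs without torch or a GPU, and `flatten_forward_args` and `check_bijection` are hypothetical helper names. The sketch raises on nested containers, mixed containers and unsupported argument types so the failure is visible, and it also compares static shapes; both are choices made for this sketch rather than behavior described in the request.

```python
from collections import namedtuple
from typing import Any, List, Sequence, Tuple


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without torch."""

    def __init__(self, shape: Sequence[int]):
        self.shape = tuple(shape)


def flatten_forward_args(args: Sequence[Any], leaf_type: type) -> List[Any]:
    """Flatten forward() arguments into one ordered list of tensors.

    Bare tensors are kept as-is; lists and tuples (namedtuples are tuple
    subclasses) are unpacked one level. Nested or mixed containers raise.
    """
    flat: List[Any] = []
    for position, arg in enumerate(args):
        if isinstance(arg, leaf_type):
            flat.append(arg)
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                if not isinstance(item, leaf_type):
                    raise TypeError(
                        f"argument {position} contains a non-tensor item "
                        f"of type {type(item).__name__}"
                    )
                flat.append(item)
        else:
            raise TypeError(
                f"argument {position} has unsupported type {type(arg).__name__}"
            )
    return flat


def check_bijection(flat: Sequence[Any], shapes: Sequence[Tuple[int, ...]]) -> None:
    """Require exactly one input shape per flattened tensor, in the same order."""
    if len(flat) != len(shapes):
        raise ValueError(f"{len(flat)} tensors but {len(shapes)} input shapes")
    for index, (tensor, shape) in enumerate(zip(flat, shapes)):
        if tuple(tensor.shape) != tuple(shape):
            raise ValueError(f"tensor {index} has shape {tensor.shape}, expected {shape}")


# Usage: a bare tensor, a list and a namedtuple mixed in one call.
Pair = namedtuple("Pair", ["t1", "t2"])
x = FakeTensor((1, 1))
args = (x, [FakeTensor((1, 1))], Pair(FakeTensor((1, 1)), FakeTensor((1, 1))))

flat = flatten_forward_args(args, FakeTensor)
check_bijection(flat, [(1, 1)] * 4)  # four tensors, four shapes, same order
```

Flattening in argument order, then in container order, is what makes the "same order" requirement well defined: the n-th Input spec always corresponds to the n-th tensor encountered.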

chaoz-dev avatar Jan 08 '22 08:01 chaoz-dev

@narendasan Let me know if the behaviors listed under Additional context make sense. In particular, I believe we currently ignore other input types going into forward... if we allow them at all?

I can try taking a crack at the implementation here later when I get a chance.

chaoz-dev avatar Jan 08 '22 09:01 chaoz-dev

Ah this might be a duplicate of #428, although this request might be slightly less ambitious.

chaoz-dev avatar Jan 08 '22 09:01 chaoz-dev

@chaoz-dev Yeah, this is reasonable. We have been working on a design doc for these sorts of features here: https://github.com/NVIDIA/Torch-TensorRT/discussions/629. @inocsin has been working on the first steps, with arbitrary mixes of tuples (since they are fixed size) and tensors as inputs and outputs. I need to check with him on whether he has a public dev branch, but help here is greatly appreciated.

narendasan avatar Jan 08 '22 19:01 narendasan

Sounds good, I'll take a look at the design doc and make some suggestions there for review. I had a quick look at the code and my naive first pass at this is to unpack input containers in torch_tensorrt/ts/_compiler.py in the compile function before it hits the actual compilation step, so the compilation always sees a flat list of tensors... I believe this should satisfy the basic aspects of inputting an iterable container of tensors.

chaoz-dev avatar Jan 08 '22 23:01 chaoz-dev

Seems reasonable to take the step of adding support for one collection of inputs of any type. But we need to do this in compiler.cpp, since we need to support both the C++ and Python APIs, and we need to be able to construct a new module with the correct interface; otherwise users cannot reuse the same input formatting code in their applications.
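
To illustrate the "correct interface" point only: the plain-Python sketch below wraps a callable that takes flat arguments so that callers can keep passing the original containers. It is not how compiler.cpp does it, and it does not address the C++ API. `ContainerInterface` and `flatten_args` are hypothetical names, and an ordinary function on floats stands in for the compiled module.

```python
from collections import namedtuple
from typing import Any, Callable, List, Sequence


def flatten_args(args: Sequence[Any]) -> List[Any]:
    """Unpack lists and tuples one level; keep everything else as-is."""
    flat: List[Any] = []
    for arg in args:
        if isinstance(arg, (list, tuple)):
            flat.extend(arg)
        else:
            flat.append(arg)
    return flat


def layout_of(args: Sequence[Any]) -> List[int]:
    """Number of flat values contributed by each positional argument."""
    return [len(a) if isinstance(a, (list, tuple)) else 1 for a in args]


class ContainerInterface:
    """Expose a flat-argument callable through the original container signature."""

    def __init__(self, flat_fn: Callable[..., Any], example_args: Sequence[Any]):
        self._flat_fn = flat_fn
        # Remember the layout seen at "compile" time so later calls can be checked.
        self._layout = layout_of(example_args)

    def __call__(self, *args: Any) -> Any:
        if layout_of(args) != self._layout:
            raise TypeError(f"expected argument layout {self._layout}, got {layout_of(args)}")
        return self._flat_fn(*flatten_args(args))


# Usage: flat_sub stands in for a compiled module that only accepts flat inputs.
Pair = namedtuple("Pair", ["t1", "t2"])


def flat_sub(a, b):
    return a - b


wrapped = ContainerInterface(flat_sub, example_args=(Pair(0.0, 0.0),))
result = wrapped(Pair(5.0, 3.0))  # same call shape as the original forward(input_)
```

With a wrapper like this, user code that already builds a `Pair` for the original model can call the compiled version unchanged, which is the reuse concern raised in the comment.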

narendasan avatar Jan 11 '22 00:01 narendasan

Yeah that makes sense. I'll take a look at this shortly.

chaoz-dev avatar Jan 11 '22 04:01 chaoz-dev

Deferring to @inocsin in #629 here

chaoz-dev avatar Feb 15 '22 17:02 chaoz-dev

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days

github-actions[bot] avatar May 17 '22 00:05 github-actions[bot]

Initial feature support has been merged

narendasan avatar Sep 02 '22 18:09 narendasan