torch2trt

[Bug] Comparing a tensor with an IntWrapper results in error `TypeError: add_constant(): incompatible function arguments.`

Open. chaoz-dev opened this issue 2 years ago · 1 comment

Currently, when a tensor is compared with a scalar, `add_missing_trt_tensors` converts the scalar to a tensor so that it can be added to the network as a TRT constant. However, this logic does not yet account for the case where the scalar is an `IntWrapper`, in which case we want to reuse the `._trt` tensor that the `IntWrapper` already holds. Instead, the conversion invokes the `IntWrapper`'s comparison operators against the given tensor, which produces a nonsensical value that the existing logic then passes to `add_constant`.

This can be observed by running the following script:

  import logging
  import tensorrt
  import torch
  import torch2trt

  from typing import List


  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)


  class Module(torch.nn.Module):
      def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
          h, w = t.shape
          return [t < float(h)]


  if __name__ == "__main__":
      tensor = torch.rand(3, 3).cuda()

      model = Module().cuda().eval()
      out = model(tensor).pop()
      print(f"Out {out}")

      model_trt = torch2trt.torch2trt(
          model, [tensor], log_level=tensorrt.Logger.INFO, min_shapes=[(1,1)], max_shapes=[(10,10)]
      )
      out_trt = model_trt(tensor).pop()
      print(f"Out TRT {out_trt}")

      assert torch.allclose(out, out_trt), "Not all close!"
      print("All close!")

      tensor = torch.rand(3,3).cuda()

      out = model(tensor).pop()
      out_trt = model_trt(tensor).pop()

      assert torch.allclose(out, out_trt), "Not all close!"
      print("All close!")

Conversion fails with the following error:

TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.INetworkDefinition, shape: tensorrt.tensorrt.Dims, weights: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IConstantLayer

Invoked with: <tensorrt.tensorrt.INetworkDefinition object at 0x7fec1b1c80b0>, (1,), 3
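The change described above amounts to checking whether the scalar already carries a `._trt` tensor before converting it to a constant. The sketch below illustrates only that check, using hypothetical stand-ins (`IntWrapper`, `ConstantLayer`, `FakeNetwork`, `scalar_to_trt`) rather than the real torch2trt and TensorRT classes. It assumes, as the description implies, that an `IntWrapper` behaves like an `int` that also holds the TRT tensor computing its value in `_trt`. It is not the actual torch2trt patch.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


class IntWrapper(int):
    """Hypothetical stand-in for torch2trt's IntWrapper: an int that also
    carries the TRT tensor (`_trt`) that computes its value at runtime."""

    def __new__(cls, value: int, trt_tensor: Any) -> "IntWrapper":
        obj = super().__new__(cls, value)
        obj._trt = trt_tensor
        return obj


@dataclass
class ConstantLayer:
    """Hypothetical stand-in for a TRT constant layer."""
    shape: Tuple[int, ...]
    weights: List[float]

    def get_output(self, index: int) -> Tuple[str, Tuple[int, ...], Tuple[float, ...]]:
        # Represent the output tensor as a plain tuple so it can be inspected.
        return ("const", self.shape, tuple(self.weights))


@dataclass
class FakeNetwork:
    """Hypothetical stand-in for INetworkDefinition that records constants."""
    constants: List[ConstantLayer] = field(default_factory=list)

    def add_constant(self, shape: Tuple[int, ...], weights: List[float]) -> ConstantLayer:
        # Mimic the real method's strictness: weights must be array-like,
        # not a bare Python scalar.
        if not isinstance(weights, list):
            raise TypeError("add_constant(): incompatible function arguments")
        layer = ConstantLayer(shape, weights)
        self.constants.append(layer)
        return layer


def scalar_to_trt(network: FakeNetwork, value: Any) -> Any:
    """Return a TRT-side tensor for a Python scalar (hypothetical helper)."""
    # An IntWrapper already has a TRT tensor: reuse it instead of freezing
    # its current value into a constant. This is checked first because an
    # IntWrapper is also an int.
    trt_tensor = getattr(value, "_trt", None)
    if trt_tensor is not None:
        return trt_tensor
    # Plain Python scalar: embed it as a one-element constant.
    return network.add_constant((1,), [float(value)]).get_output(0)


if __name__ == "__main__":
    net = FakeNetwork()
    h = IntWrapper(3, "shape_tensor_h")  # pretend TRT tensor for t.shape[0]
    print(scalar_to_trt(net, h))    # -> shape_tensor_h
    print(scalar_to_trt(net, 2.5))  # -> ('const', (1,), (2.5,))
```

The `_trt` check has to come first: because the stand-in `IntWrapper` subclasses `int`, a plain `isinstance(value, (int, float))` test would accept it as an ordinary number and freeze its current value (3 in the script above) into the network as a constant, even though the shape is dynamic.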

chaoz-dev, Feb 15 '23 00:02

I'll put up a fix shortly.

chaoz-dev, Feb 15 '23 00:02