torch2trt

[Bug] Comparison between tensor and scalar fails with error `[elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )`

Open: chaoz-dev opened this issue 2 years ago (1 comment)

Comparing a torch.Tensor with a Python scalar (here `self.scalar < t`) fails when the module is converted with torch2trt.
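Because the scalar is on the left-hand side, Python first calls `float.__lt__`, which returns `NotImplemented` for a non-float operand, and then falls back to the right operand's reflected method. As a result, `self.scalar < t` is evaluated as `t.__gt__(self.scalar)`. The sketch below shows this standard Python dispatch using a hypothetical `FakeTensor` class as a stand-in for `torch.Tensor`. It does not import PyTorch or torch2trt, so it says nothing about which converter ends up handling the call.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that records which comparison
    method Python invokes and with which argument."""

    def __init__(self) -> None:
        self.calls = []

    def __lt__(self, other):
        self.calls.append(("__lt__", other))
        return True

    def __gt__(self, other):
        self.calls.append(("__gt__", other))
        return True


t = FakeTensor()
scalar = 1e-3

# float.__lt__(t) returns NotImplemented, so Python tries the reflected t.__gt__(scalar).
scalar < t
print(t.calls)  # [('__gt__', 0.001)]
```

Writing the comparison as `t > self.scalar` produces the same `__gt__` call, so both spellings should exercise the same tensor-vs-scalar path.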

Running the following script:

  import logging
  import tensorrt
  import torch
  import torch2trt

  from typing import List


  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)


  class ScalarModule(torch.nn.Module):
      def __init__(self) -> None:
          super().__init__()
          self.scalar = 1e-3

      def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
          mask = self.scalar < t
          return [mask]


  if __name__ == "__main__":
      tensor = torch.rand(2,2).cuda()

      model = ScalarModule().cuda().eval()
      out = model(tensor).pop()
      print(f"Out {out}")

      model_trt = torch2trt.torch2trt(
          model, [tensor], log_level=tensorrt.Logger.INFO
      )
      out_trt = model_trt(tensor).pop()
      print(f"Out TRT {out_trt}")

      assert torch.allclose(out, out_trt), "Not all close!"
      print("All close!")

The script produces the following TensorRT errors:

[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 4: [layers.cpp::validate::2443] Error Code 4: Internal Error (:1:ELEMENTWISE:GPU: elementwise inputs must have same dimensions or follow broadcast rules (input dimensions were [2,2] and [1]).)
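The last log line shows TensorRT receiving a `[2,2]` tensor and a `[1]` constant. PyTorch and NumPy implicitly prepend size-1 dimensions when broadcasting. The TensorRT elementwise layer instead requires both inputs to have the same number of dimensions, which is the `x.nbDims == y.nbDims` assertion. Each dimension pair must then match or contain a 1. The minimal pure-Python sketch below models that shape rule and the usual remedy of reshaping the lower-rank operand to `[1, 1]` before the elementwise op. The helper names are hypothetical and are not part of torch2trt or TensorRT.

```python
from typing import Sequence, Tuple


def pad_to_rank(shape: Sequence[int], rank: int) -> Tuple[int, ...]:
    """Prepend size-1 dimensions until `shape` has exactly `rank` dimensions."""
    if len(shape) > rank:
        raise ValueError(f"shape {tuple(shape)} already has more than {rank} dims")
    return (1,) * (rank - len(shape)) + tuple(shape)


def trt_style_broadcastable(a: Sequence[int], b: Sequence[int]) -> bool:
    """Same-rank broadcast rule: ranks must be equal (no implicit padding),
    and each dimension pair must be equal or contain a 1."""
    if len(a) != len(b):
        return False  # the case reported as "x.nbDims == y.nbDims failed"
    return all(x == y or x == 1 or y == 1 for x, y in zip(a, b))


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Pad the lower-rank shape first, then apply the same-rank rule."""
    rank = max(len(a), len(b))
    pa, pb = pad_to_rank(a, rank), pad_to_rank(b, rank)
    if not trt_style_broadcastable(pa, pb):
        raise ValueError(f"shapes {pa} and {pb} are not broadcastable")
    # A dimension of 1 takes the other operand's size; otherwise sizes already match.
    return tuple(y if x == 1 else x for x, y in zip(pa, pb))


tensor_shape = (2, 2)  # shape of `tensor` in the repro
scalar_shape = (1,)    # shape of the scalar constant reported in the log

print(trt_style_broadcastable(tensor_shape, scalar_shape))  # False
print(pad_to_rank(scalar_shape, len(tensor_shape)))         # (1, 1)
print(broadcast_shape(tensor_shape, scalar_shape))          # (2, 2)
```

In a converter, the equivalent step would be reshaping the scalar constant to rank 2 (for example with a shuffle layer) before adding the comparison layer. That is a general approach, not necessarily the fix that was merged.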

chaoz-dev, Feb 03 '23 22:02

I'll put up a fix shortly.

chaoz-dev, Feb 03 '23 22:02