[Bug] `torch.clamp` fails to convert with error `TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.`

Open chaoz-dev opened this issue 2 years ago • 3 comments

The torch.clamp converter currently fails because it performs a calculation between a (potentially) device-allocated tensor and a CPU tensor.

Running the following script:

  import logging
  import tensorrt
  import torch
  import torch2trt


  logging.basicConfig(level=logging.INFO)
  torch.manual_seed(0)


  class ClampModule(torch.nn.Module):
      def __init__(self, lower_bound: torch.Tensor, upper_bound: torch.Tensor):
          super().__init__()
          self.lower_bound = lower_bound
          self.upper_bound = upper_bound

      def forward(self, t: torch.Tensor):
          return torch.clamp(t, min=self.lower_bound, max=self.upper_bound)


  if __name__ == "__main__":
      lower_bound = torch.tensor((1, 1), dtype=torch.float32).cuda()
      upper_bound = torch.tensor((5, 5), dtype=torch.float32).cuda()
      tensor = torch.tensor((3, 3), dtype=torch.float32).cuda()

      model = ClampModule(lower_bound, upper_bound).cuda().eval()
      out = model(tensor)

      model_trt = torch2trt.torch2trt(
          model,
          [tensor],
      )
      out_trt = model_trt(tensor)
      assert torch.allclose(out, out_trt), "Not all close"

Outputs the following error:

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
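
For readers hitting the same message: it is raised whenever `.numpy()` is called on a tensor that lives on a CUDA device. A minimal sketch of the usual host-copy pattern follows; the helper name `to_numpy` is hypothetical and is not part of torch2trt, and this only illustrates the error, not the converter fix.

```python
import torch


def to_numpy(t: torch.Tensor):
    """Copy a tensor to host memory before converting it to a NumPy array.

    Hypothetical helper for illustration only; not a torch2trt function.
    """
    # detach() drops any autograd history, cpu() is a no-op for tensors
    # that are already on the host, and numpy() then succeeds.
    return t.detach().cpu().numpy()


bound = torch.tensor((1.0, 1.0), dtype=torch.float32)
if torch.cuda.is_available():
    # Calling bound.numpy() directly on this tensor would raise the
    # TypeError shown above.
    bound = bound.cuda()

arr = to_numpy(bound)
```

The same call works for both CPU and CUDA tensors, so the caller does not need to branch on the device.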

chaoz-dev, Jan 31 '23 21:01

I'll put up a fix shortly.

chaoz-dev, Jan 31 '23 21:01

Ah, I see: the bigger issue is that the torch.clamp converter implementations currently handle only single scalars, not torch.Tensor values, for the min and max args (at least, that seems to be the intention from the unit tests). I'll adjust the implementation.
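
To make the tensor-bounds case concrete: clamping with tensor-valued min and max is equivalent to an elementwise maximum followed by an elementwise minimum. The sketch below shows that equivalence in plain PyTorch; the function name `clamp_with_tensor_bounds` is hypothetical, and this is not the torch2trt converter or the eventual fix, only one way such a converter could decompose the operation.

```python
import torch


def clamp_with_tensor_bounds(
    t: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Clamp t elementwise between tensor-valued bounds.

    Hypothetical illustration: raise values to the lower bound first,
    then cap them at the upper bound.
    """
    return torch.minimum(torch.maximum(t, lower), upper)


t = torch.tensor((0.0, 3.0, 9.0))
lower = torch.tensor((1.0, 1.0, 1.0))
upper = torch.tensor((5.0, 5.0, 5.0))

result = clamp_with_tensor_bounds(t, lower, upper)
reference = torch.clamp(t, min=lower, max=upper)
```

All three tensors must live on the same device, which is the condition the failing converter does not currently satisfy.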

chaoz-dev, Jan 31 '23 22:01

I ran into the same problem. How was it solved?

Alex-fishred, Feb 29 '24 10:02