torch2trt
[Bug] Comparison between tensor and scalar fails with error `[elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )`
Converting a model that compares a `torch.Tensor` with a Python scalar fails.
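For context, the repro puts the scalar on the left of the comparison (`self.scalar < t`). Python's `float.__lt__` does not know about tensors and returns `NotImplemented`, so Python falls back to the reflected method on the right operand, `t.__gt__(self.scalar)`. That means the conversion presumably goes through whatever torch2trt converter handles `Tensor.__gt__`, with the tensor as the first input and a plain float as the second. That converter path is an assumption, not something this issue confirms. The sketch below uses a hypothetical `TracingTensor` class, not part of torch or torch2trt, to show the dispatch in plain Python:

```python
# Hypothetical stand-in for torch.Tensor, used only to trace which comparison
# method Python invokes; it is not part of torch or torch2trt.
class TracingTensor:
    def __init__(self) -> None:
        self.calls = []

    def __gt__(self, other):
        self.calls.append(("__gt__", other))
        return True

    def __lt__(self, other):
        self.calls.append(("__lt__", other))
        return False


t = TracingTensor()
scalar = 1e-3

# float.__lt__(TracingTensor) returns NotImplemented, so Python falls back
# to the reflected operation, which for `<` is the right operand's __gt__.
result = scalar < t

print(t.calls)  # [('__gt__', 0.001)]
print(result)   # True
```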
Running the following script:
```python
import logging
from typing import List

import tensorrt
import torch
import torch2trt

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)


class ScalarModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scalar = 1e-3

    def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
        mask = self.scalar < t
        return [mask]


if __name__ == "__main__":
    tensor = torch.rand(2, 2).cuda()
    model = ScalarModule().cuda().eval()

    out = model(tensor).pop()
    print(f"Out {out}")

    model_trt = torch2trt.torch2trt(
        model, [tensor], log_level=tensorrt.Logger.INFO
    )
    out_trt = model_trt(tensor).pop()
    print(f"Out TRT {out_trt}")

    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")
```
This outputs the following:
```
[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 2: [elementWiseNode.cpp::computeOutputExtents::19] Error Code 2: Internal Error (Assertion x.nbDims == y.nbDims failed. )
[02/03/2023-15:30:05] [TRT] [E] 4: [layers.cpp::validate::2443] Error Code 4: Internal Error (:1:ELEMENTWISE:GPU: elementwise inputs must have same dimensions or follow broadcast rules (input dimensions were [2,2] and [1]).)
```
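The last log line shows the mismatch: the input tensor has dimensions `[2,2]` while the scalar ended up as a constant of dimensions `[1]`. Unlike PyTorch, the TensorRT elementwise layer does not implicitly prepend dimensions. Both inputs must have the same rank, and each pair of dimensions must be equal or contain a 1. The plain-Python sketch below models that rule and the obvious remedy, which is to left-pad the scalar's shape with 1s to the input's rank (`[1]` becomes `[1,1]`). The helpers `trt_broadcastable` and `pad_to_rank` are hypothetical and exist only for illustration. In an actual converter, the padding would typically be done with a reshape (for example, via a shuffle layer) on the constant. This is an assumption about the approach, not a description of the eventual fix.

```python
from typing import Sequence, Tuple


def trt_broadcastable(a: Sequence[int], b: Sequence[int]) -> bool:
    """Hypothetical model of the rule in the error message: ranks must match,
    and each dimension pair must be equal or contain a 1."""
    if len(a) != len(b):
        return False  # corresponds to "Assertion x.nbDims == y.nbDims failed"
    return all(x == y or x == 1 or y == 1 for x, y in zip(a, b))


def pad_to_rank(shape: Sequence[int], ndim: int) -> Tuple[int, ...]:
    """Hypothetical helper: left-pad a shape with 1s (PyTorch-style
    broadcasting) so that it has exactly `ndim` dimensions."""
    if len(shape) > ndim:
        raise ValueError(f"shape {tuple(shape)} already has more than {ndim} dims")
    return (1,) * (ndim - len(shape)) + tuple(shape)


input_shape = (2, 2)  # the tensor `t` from the repro
scalar_shape = (1,)   # the scalar constant, as reported in the error

print(trt_broadcastable(input_shape, scalar_shape))  # False
padded = pad_to_rank(scalar_shape, len(input_shape))
print(padded)                                         # (1, 1)
print(trt_broadcastable(input_shape, padded))         # True
```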
I'll put up a fix shortly.