[Bug] Comparing a tensor with an IntWrapper results in error `TypeError: add_constant(): incompatible function arguments.`
Currently, when we compare a scalar with a tensor, `add_missing_trt_tensors` tries to convert the scalar to a tensor so that it can be added as a TRT constant. However, this logic does not yet account for the fact that the scalar could be an `IntWrapper`, in which case we would want to use the `._trt` value already present on the `IntWrapper`. Instead, the existing logic invokes the `IntWrapper`'s comparison operators against the given tensor, which produces a nonsensical value that is then passed on to `add_constant`.
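The failure mode can be sketched in isolation. The snippet below is a hypothetical stand-in (the real `IntWrapper` lives in torch2trt and also carries the TRT shape tensor in `._trt`); it just shows why an `IntWrapper` slips into the scalar-to-constant branch and why the value it yields is not usable as constant weights:

```python
# Hypothetical stand-in for torch2trt's IntWrapper: an int subclass whose
# overloaded operators return wrapped Python values instead of plain ones.
class IntWrapper(int):
    def __mul__(self, other):
        return IntWrapper(int(self) * int(other))


h = IntWrapper(3)
h._trt = "<TRT shape tensor>"  # the value the conversion *should* reuse

# An isinstance check cannot tell this apart from a plain scalar, so the
# scalar-to-constant path is taken...
print(isinstance(h, int))  # True

# ...but operations on the wrapper stay in wrapper-land, so what eventually
# reaches network.add_constant is a bare Python value, not tensorrt.Weights.
print(type(h * 1))  # <class '__main__.IntWrapper'>
```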
This can be observed by running the following script:
```python
import logging
from typing import List

import tensorrt
import torch
import torch2trt

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)


class Module(torch.nn.Module):
    def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
        # During torch2trt conversion, t.shape yields IntWrapper values.
        h, w = t.shape
        return [t < float(h)]


if __name__ == "__main__":
    tensor = torch.rand(3, 3).cuda()
    model = Module().cuda().eval()

    out = model(tensor).pop()
    print(f"Out {out}")

    # Convert with dynamic shapes so the height is a runtime value.
    model_trt = torch2trt.torch2trt(
        model,
        [tensor],
        log_level=tensorrt.Logger.INFO,
        min_shapes=[(1, 1)],
        max_shapes=[(10, 10)],
    )

    out_trt = model_trt(tensor).pop()
    print(f"Out TRT {out_trt}")

    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")

    # Run again on a fresh input to exercise the converted engine.
    tensor = torch.rand(3, 3).cuda()
    out = model(tensor).pop()
    out_trt = model_trt(tensor).pop()
    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")
```
Running the script fails with the following error:
```
TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.INetworkDefinition, shape: tensorrt.tensorrt.Dims, weights: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IConstantLayer

Invoked with: <tensorrt.tensorrt.INetworkDefinition object at 0x7fec1b1c80b0>, (1,), 3
```

Note the last argument: instead of a `tensorrt.Weights` object, `add_constant` receives the raw scalar `3` (the value of `h`) produced by the broken conversion path.
I'll put up a fix shortly.
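For context, the fix should roughly amount to checking for an existing `._trt` before falling through to the scalar-to-constant branch. A minimal sketch of the idea (this paraphrases `add_missing_trt_tensors`; the surrounding logic and names are simplified, not the exact source):

```python
import torch


def add_missing_trt_tensors(network, tensors, dtype=torch.float32):
    """Simplified sketch: resolve each input to a TRT tensor."""
    trt_tensors = []
    for t in tensors:
        if hasattr(t, "_trt"):
            # Handles values that were already converted *and* IntWrapper
            # scalars, which carry their TRT shape tensor in ._trt. Checking
            # this first keeps an IntWrapper out of the constant path below.
            trt_tensors.append(t._trt)
        elif isinstance(t, (int, float)):
            # A plain Python scalar: materialize it as a TRT constant.
            shape = (1,)
            weights = (t * torch.ones(shape, dtype=dtype)).numpy()
            trt_tensors.append(network.add_constant(shape, weights).get_output(0))
        else:
            raise ValueError(f"Cannot convert {t!r} to a TRT tensor")
    return trt_tensors
```

The key point is the ordering: the `._trt` check must win over the `isinstance` check, since an `IntWrapper` passes both.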