torch2trt
[Feature Request] TensorRT does not automatically promote dtypes when comparing inputs
TensorRT requires manual promotion of input dtypes when operating on two or more tensors of different dtypes.
Observe the following:
import logging
import tensorrt
import torch
import torch2trt
from typing import List

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)

class Module(torch.nn.Module):
    def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
        h, w = t.shape
        return [t < h]

if __name__ == "__main__":
    tensor = torch.rand(3, 3).cuda()
    model = Module().cuda().eval()
    out = model(tensor).pop()
    print(f"Out {out}")

    model_trt = torch2trt.torch2trt(
        model, [tensor], log_level=tensorrt.Logger.INFO, min_shapes=[(1,1)], max_shapes=[(10,10)]
    )
    out_trt = model_trt(tensor).pop()
    print(f"Out TRT {out_trt}")
    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")

    tensor = torch.rand(3,3).cuda()
    out = model(tensor).pop()
    out_trt = model_trt(tensor).pop()
    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")
Running this outputs the following:
Out tensor([[True, True, True],
        [True, True, True],
        [True, True, True]], device='cuda:0')
[02/14/2023-19:47:17] [TRT] [I] [MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 696, GPU 354 (MiB)
[02/14/2023-19:47:19] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +447, GPU +120, now: CPU 1197, GPU 474 (MiB)
[02/14/2023-19:47:19] [TRT] [W] IElementWiseLayer with inputs input_0 and (Unnamed Layer* 3) [Shuffle]_output: first input has type Float but second input has type Int32.
[02/14/2023-19:47:19] [TRT] [W] Tensor DataType is determined at build time for tensors not marked as input or output.
[02/14/2023-19:47:19] [TRT] [E] 4: [elementWiseLayer.cpp::validate::34] Error Code 4: Internal Error (:1:ELEMENTWISE:GPU: operation LESS has incompatible input types Float and Int32)
[02/14/2023-19:47:19] [TRT] [E] 4: [elementWiseLayer.cpp::validate::34] Error Code 4: Internal Error (:1:ELEMENTWISE:GPU: operation LESS has incompatible input types Float and Int32)
Traceback (most recent call last):
  File "/home/chaoz/workspace/scratch/torch2trt/promote-type.py", line 29, in <module>
    out_trt = model_trt(tensor).pop()
  File "/home/chaoz/.anaconda3/envs/torch2trt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chaoz/.anaconda3/envs/torch2trt/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 635, in forward
    idx = self.engine.get_binding_index(input_name)
AttributeError: 'NoneType' object has no attribute 'get_binding_index'
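In this case, we are comparing a float tensor with an int32 value (`h`, taken from `t.shape`), which fails in TensorRT but is fine in PyTorch. The failed build leaves the engine as `None`, which is why the run ends in the `AttributeError` above rather than a clearer error. We should consider adding automatic type promotion during conversion as a developer quality-of-life improvement, instead of relying on the user to promote dtypes manually.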
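To illustrate what promotion means here, below is a minimal PyTorch-only sketch that runs on CPU. `promote_operands` is a hypothetical helper written for this example, not part of torch2trt or TensorRT. It uses `torch.promote_types` to find the common dtype and casts both operands to it before the comparison. A converter would need to insert equivalent casts into the TensorRT network, which this sketch does not attempt.

```python
import torch


def promote_operands(a: torch.Tensor, b: torch.Tensor):
    """Hypothetical helper: cast both operands to their common dtype."""
    common = torch.promote_types(a.dtype, b.dtype)
    return a.to(common), b.to(common)


# float32 input, like the tensor in the repro above
t = torch.rand(3, 3)
# int32 scalar, standing in for the shape value that TensorRT sees as Int32
h = torch.tensor(t.shape[0], dtype=torch.int32)

a, b = promote_operands(t, h)
result = a < b  # both operands are float32 now, so the dtypes match

print(a.dtype, b.dtype)
print(result)
```

Until something like this happens automatically during conversion, the user has to make sure both operands already share a dtype before they reach the comparison.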
I'll have a PR up shortly.