torch2trt
[torch2trt/torch2trt.py] Address issue #848: Add automatic type promotion when comparing tensors of different dtypes.
Addresses issue #848.
Depends on PR #847.
This PR adds helper functions for automatic type promotion: operands are converted to the common dtype returned by torch.promote_types. The comparison converters now use these helpers.
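For reference, a minimal standalone sketch in plain PyTorch of the promotion rule the converters rely on (the helper name promote_for_comparison is hypothetical and is not the function added in this PR; the actual converters operate on TensorRT network tensors rather than calling .to()):

import torch

def promote_for_comparison(a: torch.Tensor, b: torch.Tensor):
    # Cast both operands to the common dtype chosen by PyTorch's promotion rules,
    # so a comparison such as `t < h` (float32 vs. int64 shape value) is well defined.
    common = torch.promote_types(a.dtype, b.dtype)
    return a.to(common), b.to(common)

if __name__ == "__main__":
    t = torch.rand(3, 3)          # float32
    h = torch.tensor(t.shape[0])  # int64, as produced by the tensor's shape
    a, b = promote_for_comparison(t, h)
    assert a.dtype == b.dtype == torch.float32  # promote_types(float32, int64) -> float32
    print(a < b)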
Running the script from #848:
import logging
import tensorrt
import torch
import torch2trt
from typing import List

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)


class Module(torch.nn.Module):
    def forward(self, t: torch.Tensor) -> List[torch.Tensor]:
        h, w = t.shape
        return [t < h]


if __name__ == "__main__":
    tensor = torch.rand(3, 3).cuda()
    model = Module().cuda().eval()

    out = model(tensor).pop()
    print(f"Out {out}")

    model_trt = torch2trt.torch2trt(
        model, [tensor], log_level=tensorrt.Logger.INFO, min_shapes=[(1, 1)], max_shapes=[(10, 10)]
    )
    out_trt = model_trt(tensor).pop()
    print(f"Out TRT {out_trt}")

    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")

    tensor = torch.rand(3, 3).cuda()
    out = model(tensor).pop()
    out_trt = model_trt(tensor).pop()
    assert torch.allclose(out, out_trt), "Not all close!"
    print("All close!")
Now correctly outputs the following:
Out tensor([[True, True, True],
[True, True, True],
[True, True, True]], device='cuda:0')
[02/14/2023-19:56:43] [TRT] [I] [MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 696, GPU 354 (MiB)
[02/14/2023-19:56:45] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +447, GPU +120, now: CPU 1197, GPU 474 (MiB)
[02/14/2023-19:56:45] [TRT] [W] Tensor DataType is determined at build time for tensors not marked as input or output.
[02/14/2023-19:56:45] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +839, GPU +338, now: CPU 2036, GPU 812 (MiB)
[02/14/2023-19:56:46] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +129, GPU +58, now: CPU 2165, GPU 870 (MiB)
[02/14/2023-19:56:46] [TRT] [W] TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.4.0
[02/14/2023-19:56:46] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[02/14/2023-19:56:46] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are:
[02/14/2023-19:56:46] [TRT] [W] (# 1 (SHAPE input_0))
[02/14/2023-19:56:46] [TRT] [W] (# 0 (SHAPE input_0))
[02/14/2023-19:56:47] [TRT] [I] Total Activation Memory: 33556992
[02/14/2023-19:56:47] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[02/14/2023-19:56:47] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are:
[02/14/2023-19:56:47] [TRT] [W] (# 1 (SHAPE input_0))
[02/14/2023-19:56:47] [TRT] [W] (# 0 (SHAPE input_0))
[02/14/2023-19:56:47] [TRT] [I] Total Host Persistent Memory: 160
[02/14/2023-19:56:47] [TRT] [I] Total Device Persistent Memory: 0
[02/14/2023-19:56:47] [TRT] [I] Total Scratch Memory: 1024
[02/14/2023-19:56:47] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 4 MiB
[02/14/2023-19:56:47] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 6 steps to complete.
[02/14/2023-19:56:47] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.02852ms to assign 5 blocks to 6 nodes requiring 3072 bytes.
[02/14/2023-19:56:47] [TRT] [I] Total Activation Memory: 3072
[02/14/2023-19:56:47] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +4, now: CPU 0, GPU 4 (MiB)
[02/14/2023-19:56:47] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 4 (MiB)
Out TRT tensor([[True, True, True],
[True, True, True],
[True, True, True]], device='cuda:0')
All close!
All close!
Looks good; my only nit is that the name trt_tensor_to_dtype reads very similarly to the other dtype utilities in that module, even though the function does something different. If you can think of a more distinct name, it might be nice to go with that.