torch2trt
[Bug] `torch.squeeze` removes non-singleton dimensions
Currently, the `torch.squeeze` converter removes dimensions without checking whether it is valid to do so. A dimension should only be removed if it is a singleton (size 1); in PyTorch, squeezing a non-singleton dimension is a no-op. Removing a non-singleton dimension changes the tensor's volume, so the resulting reshape is invalid.
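For reference, the expected semantics can be sketched in plain Python, independent of torch2trt and TensorRT. `squeezed_shape` and `volume` are hypothetical helpers written for illustration only; they model the output shape `torch.squeeze` produces for a single `dim` (or for `dim=None`), not the converter's actual code.

```python
from typing import Optional, Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: Optional[int] = None) -> Tuple[int, ...]:
    """Output shape of torch.squeeze (hypothetical helper, for illustration).

    dim=None drops every size-1 dimension. With an explicit dim, that
    dimension is dropped only if its size is 1; otherwise the shape is
    returned unchanged, mirroring PyTorch's no-op behavior.
    """
    shape = tuple(shape)
    if dim is None:
        return tuple(s for s in shape if s != 1)
    rank = len(shape)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for a rank-{rank} shape")
    dim %= rank  # normalize negative indices, e.g. -1 -> rank - 1
    if shape[dim] != 1:
        return shape  # non-singleton dimension: nothing to remove
    return shape[:dim] + shape[dim + 1:]


def volume(shape: Sequence[int]) -> int:
    """Number of elements in a tensor of the given shape."""
    v = 1
    for s in shape:
        v *= s
    return v


if __name__ == "__main__":
    # The first two shapes come from the repro below; their last dimension
    # is not 1, so squeeze(-1) leaves them untouched. The third shape has a
    # trailing singleton, which is removed.
    for shape in [(2, 2, 3, 4), (1, 1, 2, 3), (2, 3, 1)]:
        out = squeezed_shape(shape, -1)
        assert volume(out) == volume(shape)  # a valid squeeze never changes volume
        print(shape, "->", out)
```

Under these semantics `squeeze(-1)` maps `(1, 1, 2, 3)` to itself, whereas dropping the last dimension unconditionally would produce a 2-element shape from a 6-element tensor.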
Minimal reproduction:

```python
import logging

import tensorrt
import torch
import torch2trt

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)


class SqueezeModule(torch.nn.Module):
    def forward(self, t: torch.Tensor):
        return t.squeeze(-1)


if __name__ == "__main__":
    tensor = torch.rand(2, 2, 3, 4).cuda()
    model = SqueezeModule().eval().cuda()
    model(tensor)

    model_trt = torch2trt.torch2trt(
        model,
        [tensor],
        min_shapes=[(1, 1, 1, 1)],
        max_shapes=[(10, 10, 10, 10)],
    )

    tensor = torch.rand(1, 1, 2, 3).cuda()
    out = model(tensor)
    out_trt = model_trt(tensor)
    assert torch.allclose(out, out_trt), f"Not all close\n{out}\n{out_trt}"
    print("All close!")
```
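Running this outputs the following. The engine fails to build, so `model_trt.engine` is `None` and the forward call raises an `AttributeError`: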
```
[02/08/2023-16:19:31] [TRT] [E] 4: [shapeCompiler.cpp::evaluateShapeChecks::911] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: reshape would change volume. IShuffleLayer :5:SHUFFLE:GPU: reshaping failed for tensor: input_0)
Traceback (most recent call last):
  File "/home/chaoz/workspace/scratch/torch2trt/squeeze.py", line 32, in <module>
    out_trt = model_trt(tensor)
  File "/home/chaoz/.anaconda3/envs/torch2trt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chaoz/.anaconda3/envs/torch2trt/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 613, in forward
    idx = self.engine.get_binding_index(input_name)
AttributeError: 'NoneType' object has no attribute 'get_binding_index'
```
I'll take a look at this issue if needed, but I think it only really comes up with dynamic shapes. In those situations, the user usually knows a priori that the dimension is safe to remove anyway; otherwise the output would have a variable number of dimensions, which a TensorRT engine cannot represent because a tensor's rank is fixed at build time.
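To make the dynamic-shape reasoning concrete, here is a hedged sketch of the decision a converter could make at build time from an optimization profile's min and max shapes. `squeeze_decision` is a hypothetical helper, not part of torch2trt; it assumes only that the output rank must be fixed when the engine is built, so the decision can depend only on the profile bounds for `dim`.

```python
from typing import Sequence


def squeeze_decision(min_shape: Sequence[int], max_shape: Sequence[int], dim: int) -> str:
    """Classify squeeze(dim) under a dynamic-shape profile (hypothetical helper).

    Returns "squeeze" if the dimension is 1 for every shape the profile
    allows, "keep" if it is never 1 (torch.squeeze is then a no-op), and
    "ambiguous" if it is 1 only for some inputs, which would make the
    output rank depend on the runtime shape.
    """
    if len(min_shape) != len(max_shape):
        raise ValueError("min and max shapes must have the same rank")
    rank = len(min_shape)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for a rank-{rank} shape")
    dim %= rank  # normalize negative indices, e.g. -1 -> rank - 1
    lo, hi = min_shape[dim], max_shape[dim]
    if lo == hi == 1:
        return "squeeze"    # size 1 for every allowed input: safe to remove
    if not lo <= 1 <= hi:
        return "keep"       # never size 1: leave the dimension in place
    return "ambiguous"      # size 1 only for some inputs: rank would vary


if __name__ == "__main__":
    # Profile from the repro: the last dimension ranges over 1..10.
    print(squeeze_decision((1, 1, 1, 1), (10, 10, 10, 10), -1))  # ambiguous
    # Last dimension pinned to 1 by the profile.
    print(squeeze_decision((1, 1, 1, 1), (10, 10, 10, 1), -1))   # squeeze
    # Last dimension pinned to 4 by the profile.
    print(squeeze_decision((1, 1, 1, 4), (10, 10, 10, 4), -1))   # keep
```

For the profile in the repro, the last dimension is neither guaranteed to be 1 nor guaranteed to differ from 1. One possible design is to raise an informative error at conversion time in the ambiguous case instead of emitting a reshape that fails when the engine is built.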