torch2trt
assert(permutation[0] == 0) # cannot move batch dim
File "/home/qitan/bonito/bonito/nn.py", line 151, in forward
    return x.permute(*self.dims)
File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.3.0-py3.8.egg/torch2trt/torch2trt.py", line 300, in wrapper
    converter["converter"](ctx)
File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.3.0-py3.8.egg/torch2trt/converters/permute.py", line 17, in convert_permute
    assert(permutation[0] == 0)  # cannot move batch dim
AssertionError
I believe this is a TRT limitation rather than a torch2trt bug: TensorRT assumes the first dimension is always the batch dimension and does not allow it to be permuted (see here), whereas PyTorch has no such restriction.
One workaround I use is to unsqueeze an extra leading dim to act as the batch dim, apply the originally intended permutation to the remaining dims (shifted by one), then squeeze the extra dim out afterwards.
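A minimal sketch of that workaround as a drop-in module (the class name `BatchSafePermute` is my own; the idea is just that the traced permutation always keeps dim 0 fixed, which satisfies the converter's assertion):

```python
import torch
import torch.nn as nn

class BatchSafePermute(nn.Module):
    """Permute a tensor without ever moving dim 0, so torch2trt's
    permute converter (which asserts permutation[0] == 0) is happy.

    `dims` is the permutation you originally wanted, e.g. (1, 0, 2).
    """
    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        x = x.unsqueeze(0)                           # add dummy batch dim
        shifted = [0] + [d + 1 for d in self.dims]   # keep dim 0, shift the rest
        x = x.permute(*shifted)                      # batch dim stays at index 0
        return x.squeeze(0)                          # drop the dummy dim

# Usage: equivalent to x.permute(1, 0, 2) on a (T, N, C) tensor
x = torch.randn(2, 3, 4)
y = BatchSafePermute(1, 0, 2)(x)
print(tuple(y.shape))  # (3, 2, 4)
```

Note the dummy dim is only there during tracing/conversion; numerically the result is identical to the plain `permute`.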