torch2trt
Converting SwinV2 from timm raises IndexError
When I convert the SwinV2 model with torch2trt, the following IndexError is raised:
File "/usr/local/lib/python3.8/dist-packages/timm/models/swin_transformer_v2.py", line 215, in forward
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 310, in wrapper
converter["converter"](ctx)
File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/Linear.py", line 8, in convert_Linear
input = ctx.method_args[0]
IndexError: tuple index out of range
Line 215 of /usr/local/lib/python3.8/dist-packages/timm/models/swin_transformer_v2.py is:
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
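Line 215 passes `input`, `weight`, and `bias` to `F.linear` as keyword arguments, while the converter in the traceback reads `ctx.method_args[0]`. A plausible explanation, assuming `method_args` holds only the positional arguments of the intercepted call, is that the tuple is empty here. The pure-Python sketch below reproduces that pattern. It does not use torch2trt's actual code: `CallContext`, `method_kwargs`, `naive_converter`, and `robust_converter` are hypothetical names, and binding with `inspect.signature` is just one way a converter could normalize positional and keyword arguments.

```python
import inspect


class CallContext:
    """Hypothetical stand-in for a traced-call context (not torch2trt's class)."""

    def __init__(self, args, kwargs):
        self.method_args = args      # positional arguments only
        self.method_kwargs = kwargs  # keyword arguments only (hypothetical name)


def linear_reference(input, weight, bias=None):
    """Mirrors the F.linear(input, weight, bias=None) signature, used only for binding."""
    return None


def naive_converter(ctx):
    # Reads the input positionally, like the failing line in the traceback.
    return ctx.method_args[0]


def robust_converter(ctx):
    # Map positional and keyword arguments onto parameter names before reading them.
    bound = inspect.signature(linear_reference).bind(*ctx.method_args, **ctx.method_kwargs)
    bound.apply_defaults()
    return bound.arguments["input"]


# A keyword-only call, as on line 215, leaves the positional tuple empty.
ctx = CallContext((), {"input": "x", "weight": "w", "bias": "b"})

try:
    naive_converter(ctx)
except IndexError as exc:
    print("naive:", exc)  # tuple index out of range

print("robust:", robust_converter(ctx))  # x
```

The robust variant returns the same value whether the call site passes `input` positionally or by keyword.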
Why does this happen, and can anyone help me fix it?
Environment:
- transformers==4.19.2
- torch2trt==0.4.0
- torch==1.13.0a0+936e930
- timm==0.6.7
The model and the conversion code are as follows:
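import timm
import torch
from torch2trt import torch2trt

# Load the pretrained SwinV2-Tiny model and move it to the GPU in inference mode.
model = timm.create_model("swinv2_tiny_window8_256", pretrained=True)
model.eval().cuda()

# Example input: a batch of 8 RGB images at 256x256 on the first GPU.
inp = torch.ones([8, 3, 256, 256])
input_data = inp.to("cuda:0")

# Build options: FP16 mode plus min/opt/max input shapes.
build_cfg = {
    "fp16_mode": True,
    "min_shape": (1, 3, 256, 256),
    "opt_shape": (1, 3, 256, 256),
    "max_shape": (8, 3, 256, 256),
}

# The conversion call below raises the IndexError shown above.
model_trt = torch2trt(model, [input_data], **build_cfg)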