torch2trt

Converting SwinV2 from timm raises IndexError

Status: Open. deepindeed2022 opened this issue 2 years ago (0 comments).

When I convert the SwinV2 model with torch2trt, the following IndexError is raised:

  File "/usr/local/lib/python3.8/dist-packages/timm/models/swin_transformer_v2.py", line 215, in forward
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
  File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/torch2trt.py", line 310, in wrapper
    converter["converter"](ctx)
  File "/usr/local/lib/python3.8/dist-packages/torch2trt-0.4.0-py3.8.egg/torch2trt/converters/Linear.py", line 8, in convert_Linear
    input = ctx.method_args[0]
IndexError: tuple index out of range

The failing line is /usr/local/lib/python3.8/dist-packages/timm/models/swin_transformer_v2.py:215:

  qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)

Why does this happen? Can anyone help me fix it?
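A likely cause, judging from the traceback: line 215 passes every argument to `F.linear` by keyword, while `convert_Linear` reads the input with `ctx.method_args[0]`. If the conversion wrapper stores positional and keyword arguments separately, a keyword-only call leaves `method_args` as an empty tuple, which produces exactly `IndexError: tuple index out of range`. The plain-Python sketch below (no torch or torch2trt needed) reproduces that pattern under this assumption. `record_call`, `convert_linear_positional`, `get_arg_like` and `convert_linear_robust` are hypothetical names; `get_arg_like` only imitates a name-then-position lookup and is not torch2trt's code.

```python
from types import SimpleNamespace


def record_call(*args, **kwargs):
    """Hypothetical stand-in for a conversion wrapper: it stores the
    positional and keyword arguments of an intercepted call separately."""
    return SimpleNamespace(method_args=args, method_kwargs=kwargs)


def convert_linear_positional(ctx):
    # Mirrors the failing pattern: assumes `input` was passed positionally.
    return ctx.method_args[0]


def get_arg_like(ctx, name, pos, default=None):
    """Hypothetical helper: look an argument up by keyword first, then by position."""
    if name in ctx.method_kwargs:
        return ctx.method_kwargs[name]
    if len(ctx.method_args) > pos:
        return ctx.method_args[pos]
    return default


def convert_linear_robust(ctx):
    # Accepts both F.linear(x, w, b) and F.linear(input=x, weight=w, bias=b).
    inp = get_arg_like(ctx, "input", 0)
    weight = get_arg_like(ctx, "weight", 1)
    bias = get_arg_like(ctx, "bias", 2)
    return inp, weight, bias


# Keyword-only call, like timm's swin_transformer_v2.py line 215.
kw_ctx = record_call(input="x", weight="W", bias="b")
# Positional call of the same function.
pos_ctx = record_call("x", "W", "b")

try:
    convert_linear_positional(kw_ctx)
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range

print(convert_linear_robust(kw_ctx))   # ('x', 'W', 'b')
print(convert_linear_robust(pos_ctx))  # ('x', 'W', 'b')
```

If torch2trt 0.4.0 provides a `get_arg(ctx, name, pos, default)` helper in `torch2trt.torch2trt` (an assumption here), the corresponding untested patch would read the input in `converters/Linear.py` with `get_arg(ctx, 'input', pos=0, default=None)` instead of `ctx.method_args[0]`. Alternatively, changing line 215 in timm to call `F.linear(x, self.qkv.weight, qkv_bias)` positionally should avoid the failing path, as the `pos_ctx` case suggests.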

Environment:

  • transformers==4.19.2
  • torch2trt==0.4.0
  • torch==1.13.0a0+936e930
  • timm==0.6.7

The model and the conversion code:

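import timm
import torch
from torch2trt import torch2trt
# Load the pretrained SwinV2-Tiny (window 8, 256x256 input) in eval mode on the GPU.
model = timm.create_model("swinv2_tiny_window8_256", pretrained=True)
model.eval().cuda()
# Example input: a batch of 8 RGB images at 256x256, moved to the GPU.
inp = torch.ones([8, 3, 256, 256])
input_data = inp.to("cuda:0")
# Build options forwarded to torch2trt as keyword arguments.
build_cfg = {
    "fp16_mode": True,
    "min_shape": (1, 3, 256, 256),
    "opt_shape": (1, 3, 256, 256),
    "max_shape": (8, 3, 256, 256),
}
# This call raises the IndexError shown above.
model_trt = torch2trt(model, [input_data], **build_cfg)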

deepindeed2022, Feb 10 '23 09:02