
Converting multi-input conv layer to tensorrt failed.

Open weizhiyi777 opened this issue 4 years ago • 14 comments

Hello guys,

I tried converting a model (PyTorch -> ONNX -> TensorRT) with one multi-input conv layer, but it failed :( Here is the script that converts the PyTorch model to ONNX:

import torch
import torch.nn.functional as F

class my_conv_model(torch.nn.Module):

    def __init__(self):
        super(my_conv_model, self).__init__()

    def forward(self, input):
        kernel = torch.rand(16, 3, 1, 1)
        output = F.conv2d(input, kernel, stride=1)
        return output

if __name__ == "__main__":
    net = my_conv_model()
    input_tensor = torch.rand(1, 3, 1024, 1728)
    
    input_names, output_names = [ "input_onnx"], [ "output_onnx" ]
    torch.onnx.export(net, input_tensor, "test.onnx", verbose=True, input_names=input_names, output_names=output_names, opset_version=10)

I also used the ONNX Python API to print some model info:

output: "1"
op_type: "RandomUniform"
attribute {
  name: "shape"
  ints: 16
  ints: 3
  ints: 1
  ints: 1
  type: INTS
}

input: "input_onnx"
input: "1"
output: "output_onnx"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "group"
  i: 1
  type: INT
}
attribute {
  name: "kernel_shape"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}

When I used the onnx2trt tool to convert this ONNX model to a TensorRT engine, I got the following error:

----------------------------------------------------------------
Input filename:   /root/data/Commonly_Used_Files/Model/solov2/onnx/solov2_test.onnx
ONNX IR version:  0.0.4
Opset version:    10
Producer name:    pytorch
Producer version: 1.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
Building TensorRT engine, FP16 available:1
    Max batch size:     32
    Max workspace size: 1024 MiB
[2021-01-04 03:20:29   ERROR] _0: kernel weights has count 0 but 48 was expected
[2021-01-04 03:20:29   ERROR] _0: count of 0 weights in kernel, but kernel dimensions (1,1) with 3 input channels, 16 output channels and 1 groups were specified. Expected Weights count is 3 * 1*1 * 16 / 1 = 48
[2021-01-04 03:20:29   ERROR] Layer _0 failed validation
[2021-01-04 03:20:29   ERROR] Network validation failed.

I have read the code, and it seems onnx-tensorrt could actually support multi-input conv layers. Could you help look into this issue?

What's more, the version info is as follows:

pytorch: 1.4.0 tensorrt: 7.1.2.8 onnx-tensorrt: 7.1.0

Thanks a lot!

Best regards, Wei Zhiyi

weizhiyi777 avatar Jan 04 '21 03:01 weizhiyi777
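The ONNX dump above shows why the parser sees an empty kernel: the weights are created inside `forward`, so they export as a `RandomUniform` node feeding the `Conv` as a second tensor input rather than as a constant initializer. If the weights are actually fixed at export time (as in the repro script), a minimal workaround, sketched here under that assumption, is to register them as a buffer:

```python
import torch
import torch.nn.functional as F

class my_conv_model(torch.nn.Module):
    def __init__(self):
        super(my_conv_model, self).__init__()
        # Registering the kernel as a buffer (instead of creating it in
        # forward) makes torch.onnx.export embed it as a constant
        # initializer, so the exported Conv node has static weights.
        self.register_buffer("kernel", torch.rand(16, 3, 1, 1))

    def forward(self, input):
        return F.conv2d(input, self.kernel, stride=1)

net = my_conv_model()
out = net(torch.rand(1, 3, 8, 8))
print(tuple(out.shape))  # (1, 16, 8, 8)
```

With the buffer in place, the exported graph contains a single-input `Conv` node with its weights stored as an initializer, which onnx-tensorrt of that era could parse.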

@weizhiyi777 No, you have a bug in your code, and onnx2trt currently does not support multi-input conv.

See my issue; you can reproduce it. Once you export the correct ONNX model, you will get an error like this:

[8] Assertion failed: ctx->network()->hasExplicitPrecision() && "TensorRT only supports multi-input conv for explicit precision QAT networks!"

lucasjinreal avatar Apr 25 '21 08:04 lucasjinreal

We only support multi-input convs for quantized networks. What is the use case of this conv?

kevinch-nv avatar May 10 '21 17:05 kevinch-nv

We only support multi-input convs for quantized networks. What is the use case of this conv?

Hi @kevinch-nv, multi-input convs are used in the SOLOv2 network, which is an instance segmentation network. One of its innovations is using features to predict the weights of convs; these predicted weights are then used for subsequent conv operations. Details can be found here:

[Paper] https://arxiv.org/abs/2003.10152 [GitHub Repository] https://github.com/WXinlong/SOLO

I think TensorRT could support multi-input conv ops. It would be very helpful for onnx-tensorrt to support them, because more and more networks are starting to predict conv weights.

weizhiyi777 avatar May 11 '21 01:05 weizhiyi777

@weizhiyi777 Some people were able to convert SOLOv2 to TensorRT: https://mp.weixin.qq.com/s/gk3Rq2kmZ159gZdYNGGvMA

lucasjinreal avatar May 11 '21 14:05 lucasjinreal

@jinfagang Thanks a lot! I saw this problem was dealt with by other solutions. I will try that.

weizhiyi777 avatar May 12 '21 03:05 weizhiyi777

I also hit this problem when using F.conv2d. Can you tell me how to solve it? @jinfagang @weizhiyi777

Xiaoyw1998 avatar May 22 '21 15:05 Xiaoyw1998

@Xiaoyw1998 I haven't solved it yet, but you may get some inspiration from this link: https://mp.weixin.qq.com/s/gk3Rq2kmZ159gZdYNGGvMA. In this blog, the author uses matrix multiplication to replace F.conv2d, which is supported by ONNX. I think you could try this solution if you have time.

weizhiyi777 avatar May 23 '21 06:05 weizhiyi777

@Xiaoyw1998 I haven't solved it yet, but you may get some inspiration from this link: https://mp.weixin.qq.com/s/gk3Rq2kmZ159gZdYNGGvMA. In this blog, the author uses matrix multiplication to replace F.conv2d, which is supported by ONNX. I think you could try this solution if you have time.

Thank you for your reply, but I want to use a 7x7 conv, which cannot be replaced by a simple matrix multiplication.

Xiaoyw1998 avatar May 23 '21 10:05 Xiaoyw1998
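For what it's worth, the matmul trick is not limited to 1x1 kernels: a k×k convolution can also be expressed as a matrix multiplication by first extracting patches with `F.unfold` (im2col). A sketch follows; the helper name `conv2d_as_matmul` is hypothetical, not taken from the blog:

```python
import torch
import torch.nn.functional as F

def conv2d_as_matmul(x, kernel, stride=1, padding=0):
    """Emulate F.conv2d with unfold + matmul (im2col), avoiding a
    multi-input Conv node in the exported ONNX graph."""
    n, c, h, w = x.shape
    out_c, in_c, kh, kw = kernel.shape
    # Sliding patches: (N, C*kh*kw, L), one column per output position.
    patches = F.unfold(x, (kh, kw), stride=stride, padding=padding)
    # Flattened kernel (out_c, C*kh*kw) times patches -> (N, out_c, L).
    out = kernel.view(out_c, -1) @ patches
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    return out.view(n, out_c, oh, ow)

x = torch.rand(1, 3, 16, 16)
k = torch.rand(8, 3, 7, 7)  # a 7x7 kernel works too
ref = F.conv2d(x, k, padding=3)
alt = conv2d_as_matmul(x, k, padding=3)
assert torch.allclose(ref, alt, atol=1e-4)  # matches F.conv2d
```

Whether this exports to operators a given TensorRT version handles depends on how `Unfold` is lowered in the ONNX export, so it is a workaround to evaluate rather than a guaranteed fix.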

@weizhiyi777 @jinfagang @Xiaoyw1998 Have you solved the problem with Multi Input Convolution?

furkancoskun avatar Aug 24 '21 07:08 furkancoskun

@weizhiyi777 @jinfagang @Xiaoyw1998 Have you solved the problem with Multi Input Convolution?

No

Xiaoyw1998 avatar Aug 24 '21 07:08 Xiaoyw1998

This is still a known limitation inside TensorRT. We are planning to support this in a future release of TensorRT, in the meantime it's recommended to export your models with static conv weights if possible.

kevinch-nv avatar Mar 21 '22 18:03 kevinch-nv

In my situation, using torch.nn.Conv2d instead of F.conv2d solved it. It seems that when you use nn.Conv2d, you initialize in_channels and out_channels and the ONNX parser knows how to deal with it, whereas F.conv2d is given the kernel weights by hand.

monsterlyg avatar Mar 31 '22 07:03 monsterlyg

In my situation, using torch.nn.Conv2d instead of F.conv2d solved it. It seems that when you use nn.Conv2d, you initialize in_channels and out_channels and the ONNX parser knows how to deal with it, whereas F.conv2d is given the kernel weights by hand.

That's because the data input and the kernel weights are two separate inputs.

monsterlyg avatar Mar 31 '22 07:03 monsterlyg
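The nn.Conv2d approach described above can be sketched as follows; the module name `StaticConv` is illustrative, and this assumes the weights are known when the module is constructed:

```python
import torch

class StaticConv(torch.nn.Module):
    """Same computation as the F.conv2d repro, but with weights that
    the ONNX exporter serializes as constant initializers."""

    def __init__(self):
        super().__init__()
        # nn.Conv2d owns its kernel as a module Parameter, so
        # torch.onnx.export emits it as an initializer rather than a
        # second runtime input to the Conv node.
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=1, bias=False)
        with torch.no_grad():
            self.conv.weight.copy_(torch.rand(16, 3, 1, 1))

    def forward(self, x):
        return self.conv(x)

net = StaticConv()
print(tuple(net(torch.rand(1, 3, 8, 8)).shape))  # (1, 16, 8, 8)
```

This only helps when the weights are fixed; if they are predicted by another branch of the network (as in SOLOv2), the Conv node genuinely needs two inputs and the limitation below applies.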

This dynamic-weights feature will be supported in the next release. Thanks!

ttyio avatar Jul 13 '22 02:07 ttyio

Is there any update on the convolution with dynamic weight feature?

erfaneshrati avatar Oct 18 '22 05:10 erfaneshrati

+1 on the convolution with dynamic weights @ttyio

rocco-haro avatar Oct 18 '22 06:10 rocco-haro

Hello @kevinch-nv. First of all, thank you very much for contributing to the open-source community!

Quick question for you: is there a branch we could pull to use this feature for "early" access instead of waiting for the next release? If not, when do you expect the next release to ship?

rocco-haro avatar Oct 19 '22 16:10 rocco-haro

+1, same problem with dynamic convolution weights. I think writing a custom operator as a TensorRT plugin may solve this problem.

wuyunnben avatar Oct 25 '22 08:10 wuyunnben

@erfaneshrati @rocco-haro @wuyunnben Thank you for your patience. The release will happen in one month.

zhenhuaw-me avatar Oct 25 '22 09:10 zhenhuaw-me

+1 waiting for dynamic weight convolution to accelerate inference time ^-^

mangoyuan avatar Nov 08 '22 03:11 mangoyuan

Guys, thank you for your patience! Please check the latest TensorRT 8.5.1 release. You can find the header description here: https://github.com/NVIDIA/TensorRT/blob/main/include/NvInfer.h#L1442

zhenhuaw-me avatar Nov 08 '22 04:11 zhenhuaw-me

Thank you @zhenhuaw-me !

rocco-haro avatar Nov 08 '22 23:11 rocco-haro

Closing this issue since this feature has been released. Feel free to reopen if you have any further questions. Thanks!

zhenhuaw-me avatar Nov 09 '22 12:11 zhenhuaw-me

So now how do I convert F.conv2d to tensorrt?

lansfair avatar Jan 04 '23 09:01 lansfair

@lansfair If you want to convert from PyTorch to TensorRT directly, you might try https://github.com/pytorch/TensorRT; otherwise you can export the PyTorch model to ONNX and let TensorRT load the ONNX model.

zhenhuaw-me avatar Jan 11 '23 10:01 zhenhuaw-me