Paddle2ONNX icon indicating copy to clipboard operation
Paddle2ONNX copied to clipboard

动态 batchsize 下的 resize 算子无法通过 OnnxRuntime 推理

Open hebangwen opened this issue 3 years ago • 7 comments

问题描述

我想要把 PaddleDetection 中的 tinypose 用 onnxruntime 部署,通过 README.md 导出的模型可以通过 ort 执行,但是经过 onnxsim 后的模型却不可以。我发现原因是 onnxsim 会保留 Paddle2ONNX 中的 scale 属性,但是由于 tinypose 的 backbone litehrnet 用的是 size 来指定缩放后的张量宽高的,而没有用 scale ,所以 scale 属性会变成空。在 interpolate.cc 中,当 size 为空时,会给参数一个空的数组占位。这个占位符就导致了动态 shape 下的报错。

测试代码结果:

batchsize use_onnxsim 推理成功
1 True / False ✔️
-1 True
-1 False ✔️

请问能否通过修改 Paddle2ONNX 支持动态 shape 下 onnxsim 之后的 resize 算子也能成功推理?

更多信息 :

  • 用于部署的推理引擎: OnnxRuntime
  • 为什么需要转换为ONNX格式:部署 tinypose
  • Paddle2ONNX版本: 0.9.8
  • 你的联系方式(Email/Wechat/Phone): [email protected]

报错截图

image

其他信息

import numpy as np
import onnx
import onnxruntime as ort
import onnxsim
import paddle
import paddle.static
import paddle.onnx
import paddle.nn as nn
import paddle.nn.functional as F


class MyInterpolate(nn.Layer):
    def __init__(self, target_size=()):
        super().__init__()
        self.target_size = target_size
    
    def forward(self, x):
        return F.interpolate(
            x,
            size=self.target_size,
            mode='bilinear',
            align_corners=True
        )


if __name__ == "__main__":
    origin_size, resized_size, channels = (64, 64), (128, 128), 3
    model_name = "paddle_resize"
    use_onnxsim = True
    batchsize = -1

    resize_op = MyInterpolate(resized_size)
    x_spec = paddle.static.InputSpec(shape=(batchsize, channels, *origin_size), dtype='float32', name='input')
    paddle.onnx.export(
        resize_op,
        model_name,
        input_spec=[x_spec],
        opset_version=11
    )

    random_input = np.random.randn(1, channels, *origin_size).astype(np.float32)
    inp = {"input": random_input}

    if use_onnxsim:
        model_simp, check = onnxsim.simplify(
            model=model_name+".onnx",
            test_input_shapes={"input": random_input.shape},
            input_data=inp)
        assert check, "Simplified ONNX model could not be validated"
        model_path = f"{model_name}-sim.onnx"
        onnx.save_model(model_simp, model_path)
        print(f"Simplified ONNX model is saved to {model_path}.")
    else:
        model_path = f"{model_name}.onnx"

    session = ort.InferenceSession(model_path)
    result = session.run(None, inp)
    print(result[0].shape)

hebangwen avatar Aug 25 '22 14:08 hebangwen

测试代码和模型:resize_op_p2o.zip,tinypose 是通过 PaddleDetection 的教程导出的。

hebangwen avatar Aug 25 '22 14:08 hebangwen

这里是需要paddle2onnx给一个scale值么?

另外这个bug是属于onnxsim还是paddle2onnx呢?

jiangjiajun avatar Aug 25 '22 14:08 jiangjiajun

我觉得不是给一个 scale 值,而是在有 size 的时候,删除 scale;或者按照 size 的值,给 scale 赋值,然后删掉 size (因为ort要求同时只能有一个可用的 scale 或者 size)。

其实我也在想这个 bug 到底是算 onnxsim 的还是算 paddle2onnx 的,因为 paddle2onnx 导出的 onnx 模型其实可以直接运行,但是经过 onnxsim 之后不行。但是 onnxsim 多出的 scale 属性又是从 paddle2onnx 导出的模型里继承的。

hebangwen avatar Aug 25 '22 14:08 hebangwen

我看了上面的模型, paddle的模型中,使用-1表示batch维度。 但onnxsim可能还没支持将-1识别为动态维度,因此在对动态维度优化时,视为了静态维度。

当前原图如下 image

优化后,生成的size中包含了-1值,运行出错 image

跟onnxsim的维护者也进行了沟通, 会尽快这两天更新代码支持哈

jiangjiajun avatar Aug 26 '22 02:08 jiangjiajun

好的,感谢回复👍

hebangwen avatar Aug 26 '22 02:08 hebangwen

@BangwenHe onnx/optimizer的代码已在如下commit更新 https://github.com/onnx/optimizer/commit/05c54e924dd377c923b3d1b38e7cc6dbe3b5b071 预计会在明天完成新版本的发版,更新onnxsim版本即可。

另外也感谢 @daquexian 的支持 ❤️

jiangjiajun avatar Aug 26 '22 10:08 jiangjiajun

def forward(self, input_ids, token_type_ids, pos_ids, att_mask):
    sequence_output, pooled_output = self.encoder(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_ids=pos_ids,
        attention_mask=att_mask)
    sequence_output = paddle.reshape(sequence_output,[-1,512,12,26])
    eca_output = self.eca(sequence_output) 
    eca_output = paddle.reshape(eca_output,[-1,512,312])
    start_logits = self.linear_start(eca_output )
    start_logits = paddle.squeeze(start_logits, -1)
    start_prob = self.sigmoid(start_logits)
    end_logits = self.linear_end(eca_output )
    end_logits = paddle.squeeze(end_logits, -1)
    end_prob = self.sigmoid(end_logits)
    return start_prob, end_prob

针对UIE的改动,模型转静态图正常,但是在加载静态图infer时出现了以下报错(楼上-1为静态维度的问题导致不能广播): onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Expand node. Name:'p2o.Expand.1' Status Message: p2o.Expand.1: right operand cannot broadcast on dim 0 LeftShape: {1,512,1,1}, RightShape: {-1,512,12,26}

更新了 onnxsim --> Successfully installed commonmark-0.9.1 onnx-simplifier-0.4.8 rich-12.5.1 通过onnxsim保存模型之后加载还是报上述错误! 期待回复!

Affectionate-0 avatar Sep 15 '22 09:09 Affectionate-0