
How can the PaddleHub model FCN_HRNet_W18_Face_Seg be converted to ONNX format?

Open 471417367 opened this issue 2 years ago • 8 comments

paddlepaddle-gpu==0.0.0.post110 paddlehub==2.1.1

The following approaches all failed:

Approach 1:

    import paddlehub as hub

    human_seg = hub.Module(name="FCN_HRNet_W18_Face_Seg")
    human_seg.save_inference_model(dirname="output/",
                                   model_filename="output/inference.pdmodel",
                                   params_filename="output/inference.pdiparams")

Error: RuntimeError: Module FCN_HRNet_W18_Face_Seg lacks input_spec, please specify it when calling save_inference_model.

Approach 2:

    import paddlehub as hub
    import paddle

    human_seg = hub.Module(name="FCN_HRNet_W18_Face_Seg")
    human_seg.save_inference_model(dirname="output/",
                                   model_filename="output/inference.pdmodel",
                                   params_filename="output/inference.pdiparams",
                                   input_spec=[None, 3, 384, 384])

Error: ValueError: The decorated function forward requires 0 arguments: [], but received 4 with (None, 3, 384, 384).
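(For reference, input_spec here likely expects a list of paddle.static.InputSpec objects rather than a raw shape list; a hedged rewrite of Approach 2 is sketched below. For this particular module it still fails with the "forward requires 0 arguments" error, so the working route is the paddle.jit.save approach further down.)

    import paddle
    import paddlehub as hub

    human_seg = hub.Module(name="FCN_HRNet_W18_Face_Seg")
    # Sketch only: wrap the shape in an InputSpec instead of passing bare ints.
    human_seg.save_inference_model(dirname="output/",
                                   model_filename="output/inference.pdmodel",
                                   params_filename="output/inference.pdiparams",
                                   input_spec=[paddle.static.InputSpec(shape=[None, 3, 384, 384], dtype='float32')])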

Approach 3:

    import paddlehub as hub
    import paddle

    human_seg = hub.Module(name="FCN_HRNet_W18_Face_Seg")
    input_spec = paddle.static.InputSpec(shape=[None, 3, 384, 384], dtype='float32', name='FCN_HRNet_W18_Face_Seg')
    paddle.onnx.export(human_seg, 'model', input_spec=[input_spec], opset_version=11)

Error: ValueError: The decorated function forward requires 0 arguments: [], but received 1 with (InputSpec(shape=(-1, 3, 384, 384), dtype=paddle.float32, name=FCN_HRNet_W18_Face_Seg),).

I also tried several load-then-resave approaches found online, and they all failed. Has anyone converted this model successfully and can share how? How does paddle2onnx handle a model that has only a single seg_model_384.pdparams parameter file? Converting a model with PaddleSeg's export.py requires passing a configs .yml file, and there is none for FCN_HRNet_W18_Face_Seg.

471417367 avatar Mar 15 '22 02:03 471417367

    class FCN_HRNet_W18_Face_Seg(nn.Layer):
        def __init__(self):
            super(FCN_HRNet_W18_Face_Seg, self).__init__()
            # Load the segmentation model
            self.seg = FCN(num_classes=2, backbone=HRNet_W18())

            # Load the model parameters
            state_dict = paddle.load(os.path.join(self.directory, 'seg_model_384.pdparams'))
            self.seg.set_state_dict(state_dict)

            # Set the model to evaluation mode
            self.seg.eval()
            # ============================================================
            input_spec = paddle.static.InputSpec(shape=[None, 3, 384, 384], dtype='float32', name='x')
            paddle.jit.save(self.seg, 'output', input_spec=[input_spec, ])

With the above approach I successfully obtained three model files: output.pdiparams, output.pdiparams.info, and output.pdmodel. Then I ran: paddle2onnx --model_dir model --model_filename output.pdmodel --params_filename output.pdiparams --opset_version 11 --save_file seg_model_384.onnx

Error:

    Traceback (most recent call last):
      File "/usr/local/bin/paddle2onnx", line 8, in <module>
        sys.exit(main())
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/command.py", line 195, in main
        input_shape_dict=input_shape_dict)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/command.py", line 159, in program2onnx
        operator_export_type=operator_export_type)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/convert.py", line 88, in program2onnx
        auto_update_opset)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/convert.py", line 36, in export_onnx
        auto_update_opset)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/graph/onnx_graph.py", line 258, in build
        auto_update_opset=auto_update_opset)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/graph/onnx_graph.py", line 85, in __init__
        self.update_opset_version()
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/graph/onnx_graph.py", line 203, in update_opset_version
        node_map, self.opset_version)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/op_mapper/op_mapper.py", line 129, in get_recommend_opset_version
        node_map, opset_version, True)
      File "/usr/local/lib/python3.6/dist-packages/paddle2onnx/op_mapper/op_mapper.py", line 174, in check_support_status
        raise NotImplementedError(error_info)
    NotImplementedError: There's 1 ops are not supported yet
    =========== sync_batch_norm ===========

471417367 avatar Mar 15 '22 03:03 471417367

    class FCN_HRNet_W18_Face_Seg(nn.Layer):
        def __init__(self):
            super(FCN_HRNet_W18_Face_Seg, self).__init__()
            # Load the segmentation model
            self.seg = FCN(num_classes=2, backbone=HRNet_W18())

            # Load the model parameters
            state_dict = paddle.load(os.path.join(self.directory, 'seg_model_384.pdparams'))
            self.seg.set_state_dict(state_dict)

            # Set the model to evaluation mode
            self.seg.eval()
            # ============================================================
            input_spec = paddle.static.InputSpec(shape=[None, 3, 384, 384], dtype='float32', name='x')
            paddle.jit.save(self.seg, 'output', input_spec=[input_spec, ])

With the above approach I successfully obtained three model files: output.pdiparams, output.pdiparams.info, and output.pdmodel. Then I ran: paddle2onnx --model_dir model --model_filename output.pdmodel --params_filename output.pdiparams --opset_version 11 --save_file seg_model_384.onnx

Error:

Hi, it looks like your dynamic-to-static model conversion succeeded, but the export to ONNX format failed. The failure is because paddle2onnx does not support the sync_batch_norm op.
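For reference, one way to confirm which op blocks the export is to list the operator types in the saved static program. A minimal sketch, assuming the output.pdmodel/output.pdiparams files produced by paddle.jit.save above are in the current directory:

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    # Load the static program saved by paddle.jit.save(self.seg, 'output', ...)
    program, feed_names, fetch_targets = paddle.static.load_inference_model(
        path_prefix='output', executor=exe)
    # Collect the operator types used in the graph; sync_batch_norm should appear here
    op_types = sorted({op.type for op in program.global_block().ops})
    print(op_types)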

KPatr1ck avatar Mar 15 '22 08:03 KPatr1ck

The failure is because paddle2onnx does not support the sync_batch_norm op.

sync_batch_norm needs to be converted to batch_norm for export; switch it to batch_norm and then export again. https://github.com/PaddlePaddle/PaddleHub/blob/f624ce5d44c056187ce012cd035ee4cd8162807e/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/layers.py#L23

KPatr1ck avatar Mar 15 '22 09:03 KPatr1ck

@KPatr1ck Thank you very much, it worked. The fix is to change the following when converting the model from dynamic to static:

    def SyncBatchNorm(*args, **kwargs):
        """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm instead"""
        # if paddle.get_device() == 'cpu':
        return nn.BatchNorm(*args, **kwargs)
        # else:
        #     return nn.SyncBatchNorm(*args, **kwargs)

Then convert to ONNX: paddle2onnx --model_dir model --model_filename output.pdmodel --params_filename output.pdiparams --opset_version 11 --save_file FCN_HRNet_W18_Face_Seg.onnx
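For a quick sanity check of the exported file, something like the following runs the ONNX model end to end (a sketch assuming onnxruntime is installed; the input name 'x' comes from the InputSpec used above):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("FCN_HRNet_W18_Face_Seg.onnx")
    input_name = sess.get_inputs()[0].name  # 'x', from the InputSpec name above
    dummy = np.random.rand(1, 3, 384, 384).astype("float32")
    outputs = sess.run(None, {input_name: dummy})
    # FCN was built with num_classes=2, so a 2-channel segmentation output is expected
    print([o.shape for o in outputs])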

471417367 avatar Mar 16 '22 01:03 471417367

Hi, can the models in PaddleHub be converted to ONNX normally? I didn't follow how you eventually got the conversion to work. My WeChat: hsliuyl

hsliuyl avatar Jun 17 '22 09:06 hsliuyl

I have only tried a few models, using paddle2onnx. Some models' operators need adjusting, e.g. changing nn.SyncBatchNorm back to nn.BatchNorm.


471417367 avatar Jun 20 '22 01:06 471417367

I have only tried a few models, using paddle2onnx. Some models' operators need adjusting, e.g. changing nn.SyncBatchNorm back to nn.BatchNorm.

Do you mean modifying the model's operators? Does the original model then need to be retrained?

hsliuyl avatar Jun 20 '22 03:06 hsliuyl

No retraining is needed. Some of the models in Hub don't seem to support continued training anyway; just change the unsupported operators in the network definition used to load the model.
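The operator swap does not touch the weights at all. A minimal sketch of the recipe used in this thread (the real wrapper lives in the module's model/layers.py linked above; how you wire it in is up to you):

    import paddle.nn as nn

    def SyncBatchNorm(*args, **kwargs):
        # Drop-in replacement: build a plain BatchNorm instead of nn.SyncBatchNorm.
        # The learnable parameters are identical, so the existing seg_model_384.pdparams
        # still loads via set_state_dict and no retraining is needed.
        return nn.BatchNorm(*args, **kwargs)

    # Hypothetical usage: make the model definition use this wrapper (e.g. edit
    # model/layers.py or patch it before the network is built), then rebuild the
    # model, load the .pdparams, and re-run paddle.jit.save + paddle2onnx as above.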


471417367 avatar Aug 01 '22 03:08 471417367