
Export ONNX model error

Aspirinkb opened this issue 2 years ago • 5 comments

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

Describe the bug
Cannot convert a trained model to ONNX format using tools/pytorch2onnx.py. When torch.onnx.export runs, it raises: TypeError: forward() got multiple values for argument 'img_metas'. I read the code: the model's forward function is wrapped and img_metas is set there. I cannot figure out why this error arises when img_metas is only set once.
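
For context, here is a minimal, self-contained sketch (not the actual mmseg code) of how this kind of TypeError can arise when an export script pre-binds img_metas onto forward with functools.partial and the wrapped forward is later called with img_metas in a positional slot again:

from functools import partial

class FakeSegmentor:
    def forward(self, img, img_metas, return_loss=True):
        return img

model = FakeSegmentor()
# The export script binds the test-time arguments once, as keywords ...
model.forward = partial(model.forward, img_metas=[{}], return_loss=False)
# ... but if the tracer then passes img_metas positionally as well, Python raises:
# TypeError: forward() got multiple values for argument 'img_metas'
model.forward('img_tensor', [{}])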

Reproduction

  1. What command or script did you run?
python tools/pytorch2onnx.py \
./configs/convnext_spot/upernet_convnext_base_fp16_960x960_10k_spot2.py \
--checkpoint /general-user/frank/spot/data/spot/upernet_convnext_base_fp16_960x960_10k_spot2/best_mDice_iter_8100.pth \
--output-file /general-user/frank/spot/mmsegmentation/work_dirs/upernet_convnext_base_fp16_960x960_10k_spot2/model.onnx \
--input-img /general-user/frank/spot/data/spot/diff_images2/val/IMG_20220707_162730_1.jpg \
--show \
--verify \
--dynamic-export \
--cfg-options \
  model.test_cfg.mode="whole"
  2. Did you make any modifications on the code or config? Did you understand what you have modified? No

  3. What dataset did you use? My own

Environment

  1. Please run python mmseg/utils/collect_env.py to collect necessary environment information and paste it here.
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback

If applicable, paste the error traceback here.

Traceback (most recent call last):
  File "tools/pytorch2onnx.py", line 387, in <module>
    pytorch2onnx(
  File "tools/pytorch2onnx.py", line 196, in pytorch2onnx
    torch.onnx.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 118, in new_func
    return old_func(*args, **kwargs)
TypeError: forward() got multiple values for argument 'img_metas'

Bug fix

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

Aspirinkb avatar Aug 05 '22 12:08 Aspirinkb

torch: 1.12.0+cu116

Aspirinkb avatar Aug 06 '22 07:08 Aspirinkb

torch: 1.12.0+cu116

Hello, have you solved the problem? I meet the same problem when I run pytorch2onnx.py.

feiqiu-cyber avatar Aug 07 '22 13:08 feiqiu-cyber

torch: 1.12.0+cu116

Hello, have you solved the problem? I meet the same problem when I run pytorch2onnx.py.

No. I guess this is a problem with MMSeg's support for ONNX export.

Aspirinkb avatar Aug 08 '22 03:08 Aspirinkb

Hi @Aspirinkb, I can run the following command successfully with torch1.9 on the CPU.

python tools/pytorch2onnx.py configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py --checkpoint checkpoints/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth --input-img demo/demo.png --show --verify  --cfg-options model.test_cfg.mode='whole'

Could you try removing --dynamic-export in your command? And I'll test it with torch 1.12.

xiexinch avatar Aug 08 '22 06:08 xiexinch

Hi @Aspirinkb, I can run the following command successfully with torch1.9 on the CPU.

python tools/pytorch2onnx.py configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py --checkpoint checkpoints/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth --input-img demo/demo.png --show --verify  --cfg-options model.test_cfg.mode='whole'

Could you try removing --dynamic-export in your command? And I'll test it with torch 1.12.

Same error with torch 1.12. Hi @RunningLeon and @zhouzaida, could you take a look at this issue if you're available?

xiexinch avatar Aug 08 '22 07:08 xiexinch

I changed the signature and body of the forward function as follows:

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, **kwargs):  # img_metas, return_loss=True
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        try:
            return_loss = kwargs.pop("return_loss")
            img_metas = kwargs.pop("img_metas")
        except KeyError as e:
            raise Exception(f"Miss params return_loss or img_meats: {e}")
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

and exported the ONNX model. Running it with onnxruntime is OK.
But please note that you must change it back before training models!

This is not the correct way to fix the ONNX export problem, but I could not just do nothing...
Waiting for an official bug fix...
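
If you would rather not edit the library source at all, an untested alternative with the same idea is to patch the forward only around the export and restore it afterwards (a sketch only, assuming BaseSegmentor is importable from mmseg.models.segmentors in mmseg 0.x):

from mmseg.models.segmentors import BaseSegmentor

_orig_forward = BaseSegmentor.forward

def _export_forward(self, img, **kwargs):
    # Same idea as the edited forward above: take img_metas/return_loss from
    # kwargs only, so they cannot be supplied twice during tracing.
    img_metas = kwargs.pop('img_metas')
    return_loss = kwargs.pop('return_loss', False)
    if return_loss:
        return self.forward_train(img, img_metas, **kwargs)
    return self.forward_test(img, img_metas, **kwargs)

# Apply the patch before the export script wraps model.forward with partial,
# run the export, then restore the original forward for training/inference.
BaseSegmentor.forward = _export_forward
try:
    pass  # torch.onnx.export(...) / pytorch2onnx(...) goes here
finally:
    BaseSegmentor.forward = _orig_forward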

Aspirinkb avatar Aug 08 '22 11:08 Aspirinkb

Is there any progress on this issue? I have the same problem.

AndPuQing avatar Aug 17 '22 03:08 AndPuQing

Something is wrong with the interaction between auto_fp16 and the partial wrapping of the model's forward. cc @zhouzaida. For now, you could comment out the @auto_fp16 decorator while using pytorch2onnx.py. https://github.com/open-mmlab/mmsegmentation/blob/dd42fa8d0125632371a41a87c20485494c973535/mmseg/models/segmentors/base.py#L96

https://github.com/open-mmlab/mmsegmentation/blob/dd42fa8d0125632371a41a87c20485494c973535/tools/pytorch2onnx.py#L170
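
For reference, the decorated forward looks roughly like the snippet below (copied by hand from the linked base.py with the docstring omitted, so double-check against the file); the temporary workaround is simply to comment out the decorator line:

    # mmseg/models/segmentors/base.py
    # @auto_fp16(apply_to=('img', ))  # temporarily disabled for ONNX export
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)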

RunningLeon avatar Aug 17 '22 06:08 RunningLeon

@Aspirinkb @AndPuQing Hi, could you try mmdeploy? The deployment feature in mmseg will be removed in the future.
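
A typical mmdeploy export command looks roughly like the one below; the deploy config name and all paths are placeholders that you would adapt to your own setup (see the mmdeploy docs for the exact configs):

python tools/deploy.py \
configs/mmseg/segmentation_onnxruntime_static-512x512.py \
/path/to/your_segmentor_config.py \
/path/to/your_checkpoint.pth \
/path/to/a_sample_image.jpg \
--work-dir work_dirs/onnx_export \
--device cpu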

RunningLeon avatar Aug 17 '22 07:08 RunningLeon

I tried mmdeploy and it worked for me.

AndPuQing avatar Aug 20 '22 16:08 AndPuQing