
[BUG] The ONNX export for VOLO does not work

Open WildChlamydia opened this issue 2 years ago • 5 comments

Describe the bug
VOLO models cannot be exported to ONNX:

==> Creating PyTorch volo_d1_224 model
.../lib/python3.8/site-packages/safetensors/torch.py:99: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  with safe_open(filename, framework="pt", device=device) as f:
pytorch-image-models/timm/utils/onnx.py:71: FutureWarning: 'torch.onnx._export' is deprecated in version 1.12.0 and will be removed in 2.0. Please use `torch.onnx.export` instead.
  torch_out = torch.onnx._export(
pytorch-image-models/timm/models/volo.py:74: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
======= Diagnostic Run torch.onnx.export version 2.1.0.dev20230710+cu118 =======
verbose: False, log level: 40
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "onnx_export.py", line 86, in <module>
    main()
  File "onnx_export.py", line 74, in main
    onnx_export(
  File "pytorch-image-models/timm/utils/onnx.py", line 71, in onnx_export
    torch_out = torch.onnx._export(
  File ".../lib/python3.8/site-packages/torch/onnx/_deprecation.py", line 30, in wrapper
    return function(*args, **kwargs)
  File ".../lib/python3.8/site-packages/torch/onnx/__init__.py", line 114, in _export
    return utils._export(*args, **kwargs)
  File ".../lib/python3.8/site-packages/torch/onnx/utils.py", line 1577, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File ".../lib/python3.8/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph = _optimize_graph(
  File ".../lib/python3.8/site-packages/torch/onnx/utils.py", line 672, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File ".../lib/python3.8/site-packages/torch/onnx/utils.py", line 1919, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File ".../lib/python3.8/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File ".../lib/python3.8/site-packages/torch/onnx/symbolic_opset18.py", line 52, in col2im
    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
TypeError: 'NoneType' object is not subscriptable
(Occurred when translating col2im).

Do you have any ideas about what the problem might be? Thank you in advance.
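(Editor's note: the TracerWarning in the log above already hints at part of the problem. `math.ceil(H / self.stride)` converts a traced value into a plain Python number, so the computed output size is baked into the graph as a constant. A minimal sketch of how that breaks shape generalization under tracing, using a hypothetical `CeilPool` module that mimics the volo.py pattern, not timm code itself:)

```python
import math
import torch

class CeilPool(torch.nn.Module):
    """Hypothetical module mimicking the math.ceil pattern in volo.py."""
    def __init__(self, stride=2):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        _, _, H, W = x.shape
        # math.ceil yields plain Python ints, so under torch.jit.trace the
        # output size below is recorded as a constant, not computed from x
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        return torch.nn.functional.adaptive_avg_pool2d(x, (h, w))

model = CeilPool()
traced = torch.jit.trace(model, torch.randn(1, 3, 8, 8))

big = torch.randn(1, 3, 16, 16)
print(model(big).shape)   # eager: torch.Size([1, 3, 8, 8])
print(traced(big).shape)  # traced: torch.Size([1, 3, 4, 4]) -- size baked in
```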

To Reproduce
Steps to reproduce the behavior:

python onnx_export.py \
volo.onnx \
--model volo_d1_224 \
--img-size 224 \
--b 1 \
--opset 18

Desktop (please complete the following information):

  • OS: Ubuntu 20.04.3 LTS
  • This repository version (commit): 394e8145551191ae60f672556936314a20232a35
  • PyTorch version w/ CUDA/cuDNN: 2.1.0.dev20230710+cu118

WildChlamydia avatar Jul 10 '23 22:07 WildChlamydia

Hmm, I'm 99% sure the issue is in this block (https://github.com/huggingface/pytorch-image-models/blob/394e8145551191ae60f672556936314a20232a35/timm/models/volo.py#L70-L91): something to do with the unfold/fold (col2im is fold) and the handling of the shapes.

Fold support was only just added to ONNX in opset 18, which you are using, so maybe coverage of the op isn't complete, or it doesn't work well with a dynamic (input-shape-dependent) size? Can opset 19 be tried?
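(Editor's note: for readers following along, `col2im` in the traceback is the ONNX name for the fold operation; `F.fold` is the inverse of `F.unfold`. A quick sketch of that relationship — with a non-overlapping window, i.e. stride equal to kernel size, the round trip is exact:)

```python
import torch
import torch.nn.functional as F

# F.unfold extracts sliding-window patches as columns; F.fold (ONNX: Col2Im)
# scatters them back. With stride == kernel_size the windows don't overlap,
# so the round trip reconstructs the input exactly.
x = torch.randn(1, 3, 8, 8)
cols = F.unfold(x, kernel_size=2, stride=2)  # shape (1, 3*2*2, 16)
y = F.fold(cols, output_size=(8, 8), kernel_size=2, stride=2)
assert torch.equal(x, y)
```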

rwightman avatar Jul 11 '23 04:07 rwightman

Thank you for your response.

It should be fully supported. However, I attempted to use opset 19, and it appears that opset 19 is currently supported by ONNX itself but not by torch.onnx.

Maybe this PR can help: https://github.com/huggingface/pytorch-image-models/pull/1708

WildChlamydia avatar Jul 12 '23 21:07 WildChlamydia

Okay, it seems that the problem lies in the type in the Fold: https://github.com/pytorch/pytorch/issues/105134#issuecomment-1635999042.

However, next I encountered another problem related to exporting VOLO: https://github.com/pytorch/pytorch/issues/97344.

To work around these issues, I manually rewrote the graph by replacing the node attribute 'axes' with a fixed input. Now everything works, but only with aten-fallback and opset 18.

Should I prepare a PR with this temporary solution?

WildChlamydia avatar Jul 14 '23 16:07 WildChlamydia

@WildChlamydia hmm, needing the aten fallback is going to make it not very practical for the many users who want ONNX export in order to use the model outside of the torch ecosystem, no?

rwightman avatar Jul 18 '23 06:07 rwightman

@rwightman

I suppose you are right.

However, without the --aten-fallback flag, torch performs this check:

if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and (
    not val_use_external_data_format
):
    try:
        _C._check_onnx_proto(proto)
    except RuntimeError as e:
        raise errors.CheckerError(e) from e

As a result, the model won't be saved. There are only bad options here: it is still possible to export the model successfully by commenting out these lines, and the result works fine.

I think we should wait until all bugs in opset 18 are fixed.

WildChlamydia avatar Jul 18 '23 09:07 WildChlamydia

I tried this again and it's still not working: get past one layer of the export and another is broken. I tried the newer dynamo export, but that is also broken (and broken on many more models). Sigh.

I don't feel there is much I can do on the timm end of things. So closing this.

rwightman avatar Apr 10 '24 16:04 rwightman