
[PFTO] Support ONNX export with list inputs

Open xuzijian629 opened this issue 2 years ago • 4 comments

To support list inputs and models with list operators in ppe.onnx.export_testcase, there are two topics to address.

How to handle list inputs

Roughly, we have two choices:

  • Unroll lists and assume that ppe.onnx.export always exports an ONNX model and a testcase whose inputs are all Tensors.
  • Allow list inputs (i.e., Sequence-typed inputs in ONNX).

I think we should go with the first option because that is what torch.onnx.export does. For example, torch.onnx.export actually accepts calls like torch.onnx.export(model, (list_arg,), input_names=["a", "b", "c"], ...).

This is because torch.onnx.export calls torch.jit._get_trace_graph, which automatically unrolls list inputs. However, that API is internal, and with the public torch.jit.trace, list inputs are kept as lists. (For more detail, see https://github.com/pfnet/pytorch-pfn-extras/issues/572#issuecomment-1184040553).
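To illustrate, here is a minimal pure-Python sketch (no torch dependency; `unroll_inputs` is a hypothetical helper, not an actual torch function) of the kind of input unrolling that torch.jit._get_trace_graph performs, which is why torch.onnx.export can pair each list element with its own entry in input_names:

```python
def unroll_inputs(args):
    """Flatten one level of list/tuple arguments into a flat tuple of leaves.

    This mimics (in a simplified way) how list inputs become separate
    graph inputs during tracing; the real implementation is recursive
    and operates on Tensors.
    """
    flat = []
    for arg in args:
        if isinstance(arg, (list, tuple)):
            flat.extend(arg)
        else:
            flat.append(arg)
    return tuple(flat)

# A call like torch.onnx.export(model, (list_arg,), input_names=["a", "b", "c"])
# works because the three list elements become three separate graph inputs:
print(unroll_inputs(([10, 20, 30],)))  # (10, 20, 30)
```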

Since PFTO uses torch.jit.trace, one viable way is to create a wrapper model that accepts unrolled inputs.
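A rough sketch of that wrapper idea (the names UnrollWrapper and list_len are illustrative only; in practice this would be a torch.nn.Module so torch.jit.trace can trace it):

```python
class UnrollWrapper:
    """Wraps a model whose forward expects a list argument, exposing a
    signature that takes the unrolled tensors instead, so tracing sees
    only flat Tensor inputs."""

    def __init__(self, model, list_len):
        self.model = model          # original model expecting a list argument
        self.list_len = list_len    # number of tensors in that list

    def __call__(self, *flat_args):
        # Rebuild the list from the first `list_len` flat arguments,
        # then call the original model with its expected signature.
        list_arg = list(flat_args[:self.list_len])
        rest = flat_args[self.list_len:]
        return self.model(list_arg, *rest)

# Toy model taking a list plus a scalar, to show the call convention:
def toy_model(xs, bias):
    return sum(xs) + bias

wrapped = UnrollWrapper(toy_model, list_len=3)
print(wrapped(1, 2, 3, 10))  # 16, same as toy_model([1, 2, 3], 10)
```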

Add support for more list operators

We have to implement custom symbolic execution for prim::ListUnpack, prim::TupleConstruct, etc. (prim::ListConstruct, the most commonly used one, is already implemented).

I haven't fully understood the symbolic execution of prim::ListConstruct-like nodes in torch.onnx.export. For future reference, I leave some notes:

  • Starting from torch v1.12.0, symbolic functions for list ops were introduced in symbolic_opset9.py (but we can also export with v1.11.0 and earlier)
  • torch.onnx.export runs the onnx_peephole optimization, which includes eraseListConstruct and eraseListUnpack, after _C._jit_pass_onnx(graph, operator_export_type). In my understanding, prim::ListConstruct and similar ops are replaced with ONNX operators by symbolic execution. Why do they remain after _C._jit_pass_onnx?
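As a toy illustration (not the actual torch internals) of what a peephole pass like eraseListConstruct does: when a prim::ListConstruct node only feeds an ONNX op that accepts variadic inputs (e.g. Concat), the list node can be erased and its inputs wired directly into the consumer.

```python
def erase_list_construct(nodes):
    """nodes: list of (op, inputs, output) triples; returns a rewritten list
    with prim::ListConstruct nodes erased and their inputs inlined."""
    # Map each list output to the tensor inputs that built it.
    list_inputs = {out: ins for op, ins, out in nodes
                   if op == "prim::ListConstruct"}
    rewritten = []
    for op, ins, out in nodes:
        if op == "prim::ListConstruct":
            continue  # erased: consumers take its inputs directly
        new_ins = []
        for i in ins:
            # Replace a reference to an erased list with its elements.
            new_ins.extend(list_inputs.get(i, [i]))
        rewritten.append((op, new_ins, out))
    return rewritten

graph = [
    ("prim::ListConstruct", ["x", "y"], "lst"),
    ("onnx::Concat", ["lst"], "out"),
]
print(erase_list_construct(graph))  # [('onnx::Concat', ['x', 'y'], 'out')]
```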

It seems to me that the handling of prim::ListConstruct-like ops has not been stable across torch versions. Maybe we should wait a little before stabilizing our implementation.

xuzijian629 avatar Jul 13 '22 08:07 xuzijian629

We can add a custom handler for ListUnpack, as we do for ListConstruct, here: https://github.com/pfnet/pytorch-pfn-extras/blob/a27e3d4030dcdeeedef11ad2ca4cd022c47b45c2/pytorch_pfn_extras/onnx/pfto_exporter/export.py#L458-L463
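A hedged sketch (toy code, not the actual PFTO handler API) of what such a ListUnpack handler could do: when the unpacked value comes from a prim::ListConstruct, the pair cancels out and the unpack outputs are just the constructed inputs.

```python
def handle_list_unpack(unpack_input, producers):
    """producers maps a value name to the (op, inputs) pair that produced it.

    If the input was built by prim::ListConstruct, unpacking it is a
    passthrough: unpack(construct(x, y)) -> (x, y). General Sequence
    unpacking would need real ONNX SequenceAt-style lowering.
    """
    op, ins = producers.get(unpack_input, (None, None))
    if op == "prim::ListConstruct":
        return ins  # passthrough of the original elements
    raise NotImplementedError("general Sequence unpacking not covered here")

producers = {"lst": ("prim::ListConstruct", ["x", "y"])}
print(handle_list_unpack("lst", producers))  # ['x', 'y']
```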

xuzijian629 avatar Jul 14 '22 01:07 xuzijian629

PFTO seems to enable onnx_peephole by default, so the peephole optimizations eraseListConstruct and eraseListUnpack also run in PFTO. However, currently, the onnx_peephole optimization is placed after run_symbolic_function (in generate_onnx_node).

xuzijian629 avatar Jul 14 '22 02:07 xuzijian629

torch.onnx.export seems to automatically unroll list inputs.

This is done in torch.jit._get_trace_graph https://github.com/pytorch/pytorch/blob/05ce013338b3882136eea394c37c57e29e43df1a/torch/jit/_trace.py#L95

This API is considered internal, and torch.jit.trace is recommended for public use. However, torch.jit.trace doesn't unroll list inputs.

xuzijian629 avatar Jul 14 '22 06:07 xuzijian629

It seems torch.onnx.export exports sequence inputs when the model is scripted (since scripted modules don't know the number of tensors in lists). So, sequence inputs are essential for scripted models.

xuzijian629 avatar Jul 19 '22 01:07 xuzijian629