pytorch-pfn-extras
[PFTO] Support ONNX export with list inputs
To support list inputs and models with list operators in `ppe.onnx.export_testcase`, there are two topics to address.
How to handle list inputs
Roughly, we have two choices:
- Unroll lists and assume that `ppe.onnx.export` always exports an ONNX model and testcase whose inputs are all `Tensor`s.
- Allow list inputs (i.e., `Sequence`-typed inputs in ONNX).

I think we should go with the first way because `torch.onnx.export` does so.
In fact, `torch.onnx.export` accepts calls like `torch.onnx.export(model, (list_arg,), input_names=["a", "b", "c"], ...)`. This is because `torch.onnx.export` calls `torch.jit._get_trace_graph`, and list inputs are automatically unrolled there. However, that API is internal, and with the public `torch.jit.trace`, list inputs are kept as lists. (For more detail, see https://github.com/pfnet/pytorch-pfn-extras/issues/572#issuecomment-1184040553.)
Since PFTO uses `torch.jit.trace`, one viable way is to create a wrapper model that accepts unrolled inputs.
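A sketch of that approach (the class names here are hypothetical, not PFTO's actual implementation): the wrapper accepts flat tensor arguments and rebuilds the list before calling the original model, so `torch.jit.trace` only ever sees `Tensor` inputs.

```python
import torch


class ListUnrollWrapper(torch.nn.Module):
    """Hypothetical wrapper: accepts unrolled tensor arguments and
    regroups them into the list the wrapped model expects."""

    def __init__(self, model, list_len):
        super().__init__()
        self.model = model
        self.list_len = list_len

    def forward(self, *flat_args):
        # Rebuild the list input from the flat tensor arguments.
        return self.model(list(flat_args[: self.list_len]))


class SumList(torch.nn.Module):
    def forward(self, xs):
        out = xs[0]
        for x in xs[1:]:
            out = out + x
        return out


wrapped = ListUnrollWrapper(SumList(), list_len=3)
a, b, c = (torch.ones(2) * s for s in (1.0, 2.0, 3.0))
# torch.jit.trace now sees three Tensor inputs instead of one list input.
traced = torch.jit.trace(wrapped, (a, b, c))
print(traced(a, b, c))  # tensor([6., 6.])
```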
Add support for more list operators
We have to implement custom symbolic functions for `prim::ListUnpack`, `prim::TupleConstruct`, etc. (`prim::ListConstruct`, which is the most used one, is already implemented).
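For context on where these nodes come from, a toy traced module (made up for demonstration) shows both: passing a Python list to `torch.stack` inserts `prim::ListConstruct`, and unpacking the `Tensor[]` result of `torch.unbind` inserts `prim::ListUnpack`.

```python
import torch


class StackUnbind(torch.nn.Module):
    def forward(self, x, y):
        # The Python list argument becomes a prim::ListConstruct node
        # feeding aten::stack.
        stacked = torch.stack([x, y])
        # Multiple assignment from unbind's Tensor[] result becomes a
        # prim::ListUnpack node.
        a, b = torch.unbind(stacked, 0)
        return a + b


traced = torch.jit.trace(StackUnbind(), (torch.ones(2), torch.ones(2)))
graph_str = str(traced.graph)
print(graph_str)
```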
I haven't fully understood the symbolic execution of `prim::ListConstruct`-like nodes in `torch.onnx.export`.
For future investigation, I leave some notes:
- From torch v1.12.0, symbolic functions for list ops were introduced in `symbolic_opset9.py` (but we can also export before v1.11.0).
- `torch.onnx.export` does the `onnx_peephole` optimization, which includes `eraseListConstruct` and `eraseListUnpack`, after `_C._jit_pass_onnx(graph, operator_export_type)`. In my understanding, `prim::ListConstruct` and similar ops are replaced with ONNX operators by symbolic execution. Why do they remain after `_C._jit_pass_onnx`?
It seems to me that the handling of `prim::ListConstruct`-like ops has not been stable. Maybe we should wait a little bit before stabilizing our implementation.
We can add a custom handler for `ListUnpack`, like the existing one for `ListConstruct`, here:
https://github.com/pfnet/pytorch-pfn-extras/blob/a27e3d4030dcdeeedef11ad2ca4cd022c47b45c2/pytorch_pfn_extras/onnx/pfto_exporter/export.py#L458-L463
PFTO seems to enable `onnx_peephole` by default, so the peephole optimizations `eraseListConstruct` and `eraseListUnpack` also run in PFTO. However, currently, the `onnx_peephole` optimization is placed after `run_symbolic_function` (in `generate_onnx_node`).
`torch.onnx.export` seems to automatically unroll list inputs. This is done in `torch.jit._get_trace_graph`:
https://github.com/pytorch/pytorch/blob/05ce013338b3882136eea394c37c57e29e43df1a/torch/jit/_trace.py#L95
This API is assumed to be internal, and `torch.jit.trace` is recommended for public use. However, `torch.jit.trace` doesn't unroll list inputs.
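A small check of this difference (the toy module is illustrative): when the example input to `torch.jit.trace` is a list, the traced graph keeps a single `List[Tensor]` input rather than unrolling it into separate tensor inputs.

```python
import torch


class AddPair(torch.nn.Module):
    def forward(self, xs):
        return xs[0] + xs[1]


# Pass the list itself as the example input; torch.jit.trace keeps it
# as one List[Tensor] graph input instead of unrolling it.
traced = torch.jit.trace(AddPair(), ([torch.ones(2), torch.zeros(2)],))
xs_input = list(traced.graph.inputs())[1]  # index 0 is `self`
print(xs_input.type())
```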
It seems `torch.onnx.export` exports sequence inputs when the model is scripted (since script modules don't know the number of tensors in lists). So, sequence inputs are essential for scripted models.
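To illustrate why (with a toy scripted module, not from PFTO): under `torch.jit.script`, the signature keeps the `List[Tensor]` annotation and the list length is only known at call time, so an unrolled representation is impossible and only a `Sequence`-typed ONNX input could represent it.

```python
from typing import List

import torch


class SumList(torch.nn.Module):
    def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:
        out = xs[0]
        for x in xs[1:]:
            out = out + x
        return out


# Scripting keeps the List[Tensor] signature; the number of tensors is
# decided only at call time, so it cannot be unrolled at export time.
scripted = torch.jit.script(SumList())
print(scripted([torch.ones(2)] * 4))  # tensor([4., 4.])
```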