pytorch-pfn-extras
Generate name for each member of list arg
Problem
ppe.onnx.export_testcase fails with list input. This is because the number of input_names is incorrect: the number of input tensors in the exported onnx is the total number of tensors, not len(args).
Reproduction
import torch
import pytorch_pfn_extras.onnx as tou
import torch.onnx.symbolic_helper
import torch.onnx.symbolic_registry

@torch.onnx.symbolic_helper.parse_args("v", "b", "f")
def pad_sequence(g, input, batch_first, padding_value):
    assert batch_first, "batch_first=False is not supported"
    ret = g.op("org.chainer::ChainerSequencePad", input, value_f=padding_value)
    return ret

for opset in range(9, 16):
    torch.onnx.symbolic_registry.register_op("pad_sequence", pad_sequence, "", opset)

class Model(torch.nn.Module):
    def forward(self, xs):
        return torch._C._nn.pad_sequence(xs, True, 0)

model = Model()
args = ([torch.rand(2, 5), torch.rand(3, 5)],)
torch.onnx.export(model, args, "test.onnx")  # Success!!
tou.export_testcase(model, args, "test")  # Fails!!
Error:
Traceback (most recent call last):
  File "hoge.py", line 36, in <module>
    tou.export_testcase(model, args, "test", use_pfto=False)
  File "/mnt/vol21/joe/pfvm/third_party/pytorch-pfn-extras/pytorch_pfn_extras/onnx/export_testcase.py", line 303, in export_testcase
    used_input_index_list.append(input_names.index(used_input.name))
ValueError: 'onnx::SequenceConstruct_1' is not in list
torch version: 1.12.0
This PR
Unrolls list args in the input and generates names for all tensors.
For example, if the input args are args=([[a,b],c], d), the generated input names are [input_0_0_0, input_0_0_1, input_0_1, input_1], whereas current master generates [input_0, input_1].
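The naming scheme in the example above can be sketched as a small recursive function. This is only an illustration, not the code in this PR: generate_input_names and its prefix parameter are hypothetical names, and plain strings stand in for tensors.

```python
def generate_input_names(args, prefix="input"):
    """Return one name per leaf value, descending into lists and tuples.

    Hypothetical helper for illustration; not part of pytorch-pfn-extras.
    """
    names = []
    for i, arg in enumerate(args):
        name = f"{prefix}_{i}"
        if isinstance(arg, (list, tuple)):
            # A nested container contributes one name per member,
            # each extended with the member's index.
            names.extend(generate_input_names(arg, name))
        else:
            # A leaf (a tensor in the real use case) gets a single name.
            names.append(name)
    return names


# Placeholders standing in for tensors.
a, b, c, d = "a", "b", "c", "d"
names = generate_input_names(([[a, b], c], d))
print(names)
```

With this scheme the number of names equals the total number of leaves, so it matches the input count of a graph whose list inputs have been unrolled.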
/test
todo: add test
Update: I summarized the problem of extending the current ppe.onnx.export to models with list inputs at https://github.com/pfnet/pytorch-pfn-extras/issues/572. In short, torch.onnx.export automatically unrolls list inputs in its internal trace API (torch.jit._get_trace_graph). However, torch's public trace API (torch.jit.trace) does not unroll list inputs. So maybe we should create a wrapper class for export that unrolls list inputs. What do you think about this?
With 6c732649a7ffda442d9c3beaf12427f07e2f0375, onnx export with use_pfto=False works well (because ppe.onnx.export then delegates to torch.onnx.export, where list inputs are automatically unrolled, so it makes sense to unroll inputs to generate names).
Note that with use_pfto=True, the exported onnx currently has a single list input (sequence-typed in ONNX), so we MUST NOT unroll input args.
My idea: always unroll args to follow the torch.onnx.export style. For PFTO, we wrap the module for list inputs (future work).
memo: also support list and tuple outputs