onnx2torch
Support for exporting traced graph?
Hi, I have a use case that ingests ONNX models, converts them to PyTorch, and then exports a traced graph (via `torch.export.export()`).
After converting ONNX to torch (via `onnx2torch.convert()`), I'm running into issues tracing the graph due to dynamic control flow in the converted torch model.
Is there any plan for onnx2torch to support this type of use case? Or are there any recommendations for working around the dynamic control flow in the converted torch model?
An example of a problematic op is Reshape: the converted torch model contains logic that branches on the values of the `shape` input, to replicate ONNX's special handling of shape entries equal to 0 (meaning "use the input's size for that dimension").
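For reference, the ONNX Reshape rules that the converter has to replicate can be sketched in plain Python. This is a minimal sketch based on the ONNX Reshape operator spec (with `allowzero=0`, a 0 entry copies the input dimension at the same index; at most one entry may be -1, inferred from the element count). The function name `resolve_onnx_reshape_shape` is hypothetical and not part of onnx2torch. The point it illustrates: when the target shape is a graph constant, the zeros can be resolved using only Python ints, so no tensor-valued branch is needed.

```python
from math import prod


def resolve_onnx_reshape_shape(input_shape, target_shape, allowzero=0):
    """Hypothetical helper: resolve an ONNX Reshape target shape to concrete ints.

    With allowzero=0, a 0 entry copies the input dimension at the same index.
    At most one entry may be -1; it is inferred from the total element count.
    """
    input_shape = [int(d) for d in input_shape]
    resolved = []
    for i, dim in enumerate(target_shape):
        dim = int(dim)
        if dim == 0 and not allowzero:
            if i >= len(input_shape):
                raise ValueError(f"0 at index {i} has no matching input dimension")
            resolved.append(input_shape[i])  # copy the input dimension
        else:
            resolved.append(dim)

    if resolved.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in resolved:
        known = prod(d for d in resolved if d != -1)
        total = prod(input_shape)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {input_shape} -> {list(target_shape)}")
        resolved[resolved.index(-1)] = total // known
    return resolved


print(resolve_onnx_reshape_shape((10, 20), (20, 10)))  # [20, 10]
print(resolve_onnx_reshape_shape((2, 3, 4), (0, -1)))  # [2, 12]
```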
Here's code to reproduce that issue:
```python
import io

import onnx
import onnx2torch
import torch


# Create an ONNX model containing a single Reshape node:
class M(torch.nn.Module):
    def forward(self, x):
        x = x.reshape(20, 10)
        return x


torch_args = (torch.rand(10, 20),)
with io.BytesIO() as tmp_file:
    torch.onnx.export(model=M(), args=torch_args, f=tmp_file)
    onnx_model = onnx.load_from_string(tmp_file.getvalue())

# Convert ONNX --> torch:
converted_torch = onnx2torch.convert(onnx_model)

# Export a traced graph (ExportedProgram):
ep = torch.export.export(converted_torch, args=torch_args)
```
This raises the following error (snippet; the actual trace is very long):

```
...
UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
  File "<eval_with_key>.1", line 6, in forward
    reshape = self.Reshape(input_1, constant); input_1 = constant = None
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 36, in forward
    return _forward()
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 31, in _forward
    return self._do_reshape(input_tensor, shape)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 20, in _do_reshape
    if torch.any(shape == 0):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
Any insights would be appreciated. Thanks!