[ONNX] Use the `dynamo=True` option from PyTorch 2.5
Feature request
PyTorch 2.5 introduces the torch.onnx.export(..., dynamo=True) option with improved export logic for converting models to ONNX. We recommend that optimum leverage this feature.
Motivation
The torch.onnx.export(..., dynamo=True) logic leverages the new model-capturing mechanism offered by torch.export, which is how PyTorch will capture models going forward. Additionally, better optimization and external data support are built into this logic path.
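For reference, a minimal sketch of the new call (assuming PyTorch >= 2.5; the toy model and file names are illustrative):

```python
import torch

class SmallMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

model = SmallMLP().eval()
example_input = (torch.randn(2, 8),)

# New exporter path: with dynamo=True the call returns an ONNXProgram,
# so saving the .onnx file is a separate, explicit step.
onnx_program = torch.onnx.export(model, example_input, dynamo=True)
onnx_program.save("small_mlp.onnx")
```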
This is a continuation of https://github.com/huggingface/optimum/pull/1712
Your contribution
I can create a PR
Hi! That's super cool. So we won't need to call torch.onnx.dynamo_export(...) anymore? And how different are the two?
Hi! The dynamo=True option in PyTorch 2.5 uses new logic that leverages the torch.export ExportedProgram and the ONNX IR for capturing the graph and constructing ONNX models. It is more robust and produces more optimized graphs than the current dynamo_export. dynamo_export will be updated to use the same logic starting from PyTorch 2.6 (tentative).
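To illustrate the relationship, here is a rough sketch (the toy model is made up; assuming PyTorch >= 2.5) of capturing with torch.export first and then converting the resulting ExportedProgram, which the dynamo=True path can accept directly:

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return torch.nn.functional.gelu(self.linear(x))

example_input = (torch.randn(4, 16),)

# Capture the model with torch.export; this is the graph representation
# the new exporter consumes.
exported_program = torch.export.export(TinyModel().eval(), example_input)

# Hand the ExportedProgram to the ONNX exporter.
onnx_program = torch.onnx.export(exported_program, example_input, dynamo=True)
onnx_program.save("tiny_model.onnx")
```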
In my experience, using dynamo=True has caused a huge loss in accuracy in downstream formats (ONNX and TRT) for detection models like DFINE.
Could you share example models?
@justinchuby also following up from this: issue
Do we have a way to enable modules_as_functions in the dynamo ONNX export?
When I tried it this way, it doesn't encapsulate the nn.Modules as functions. (Is it completely deprecated in torch==2.9.0?)
```python
onnx_program = torch.onnx.export(
    model,
    (tensor_input, dict_input, list_input),
    dynamic_shapes=({0: "batch_size"}, {"tensor_x": {0: "batch_size"}}, [{0: "batch_size"}]),
    input_names=input_names,
    output_names=output_names,
    dynamo=True,
    export_modules_as_functions=True,
    do_constant_folding=True,
)
```
There isn’t an option. What’s the use case? Users are encouraged to write their own passes to do this using onnx-ir and metadata associated with the model.
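Not an official recipe, but a rough sketch of what such a pass could start from, using the onnx_ir package (the "namespace" metadata key below is an assumption; inspect node.metadata_props on your exported model to see what the exporter actually records):

```python
import onnx_ir as ir  # pip install onnx-ir

model = ir.load("model.onnx")

# Group nodes by the module scope the exporter recorded in metadata_props.
# The exact key depends on the exporter version; "namespace" is an assumption.
nodes_by_scope: dict[str, list[ir.Node]] = {}
for node in model.graph:
    scope = node.metadata_props.get("namespace", "")
    nodes_by_scope.setdefault(scope, []).append(node)

for scope, nodes in nodes_by_scope.items():
    print(scope, [n.op_type for n in nodes])
    # A custom pass could extract each group into an ONNX local function here.
```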
@justinchuby the issue was resolved once I upgraded to torch 2.9. So all good!
@justinchuby Thanks! Is there documentation on how to write such passes, e.g. to create functions and get a graph structure similar to what export_modules_as_functions={MODULES} produces today? The use case is to have repeated blocks share a single, unique FunctionProto signature, so that it's easier to optimize or speed up the compilation process.
Hi @justinchuby, export_modules_as_functions is very convenient for writing ONNX local functions and loading them in TensorRT. A typical example is exporting the weightless part of the attention module from an LLM as a parameterized function, and then reimplementing continuous batching and FlashInfer-based paged inference inside TensorRT plugins. Without export_modules_as_functions, implementing similar functionality (e.g. through an ONNX plugin) would be more cumbersome.
```python
import torch

class LocalFunctionForAttn(torch.nn.Module):
    params: str

    def __init__(self):
        super().__init__()
        self.params = ""

    def add_params(self, key, value):
        # Accumulate "key=value" pairs as a ";"-separated attribute string.
        if len(self.params) != 0:
            self.params += ";"
        self.params += f"{key}={value}"

    def forward(self, q, k, v):
        # Placeholder body for export; the real attention is supplied by a TensorRT plugin.
        return q + k + v
```
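And, for context, a minimal sketch of how a model containing this module would be exported with the legacy TorchScript exporter (assuming a PyTorch release where export_modules_as_functions is still honored; the wrapper model is illustrative):

```python
class AttnWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = LocalFunctionForAttn()
        self.attn2 = LocalFunctionForAttn()

    def forward(self, q, k, v):
        return self.attn2(self.attn1(q, k, v), k, v)

q = k = v = torch.randn(1, 8, 64)

# Legacy TorchScript-based exporter: both LocalFunctionForAttn instances are
# emitted as calls to one ONNX local function with a single shared signature.
torch.onnx.export(
    AttnWrapper(),
    (q, k, v),
    "attn_with_functions.onnx",
    opset_version=17,
    export_modules_as_functions={LocalFunctionForAttn},
)
```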
https://github.com/torchpipe/torchpipe/blob/7c6901844527db1be59a0eaf5fb5c8e5927fe99a/plugins/torchpipe/examples/llama2/models/export_onnx_v2.py#L225