
[ONNX] Use the `dynamo=True` option from PyTorch 2.5

Open justinchuby opened this issue 1 year ago • 9 comments

Feature request

PyTorch 2.5 introduces the torch.onnx.export(..., dynamo=True) option and improved export logic for converting models to ONNX. We recommend that optimum leverage this feature.

Motivation

The torch.onnx.export(..., dynamo=True) logic leverages the new model capturing mechanism offered by torch.export, which is how PyTorch will capture models moving forward. Additionally, better graph optimization and external-data support are built into this logic path.

This is a continuation of https://github.com/huggingface/optimum/pull/1712

Your contribution

I can create a PR

justinchuby avatar Sep 17 '24 02:09 justinchuby

Hi! That's super cool. So we won't need to call torch.onnx.dynamo_export(...) anymore? And how different are the two?

IlyasMoutawwakil avatar Sep 20 '24 10:09 IlyasMoutawwakil

Hi! The dynamo=True option in PyTorch 2.5 uses new logic that leverages the torch.export ExportedProgram and the ONNX IR for capturing the graph and constructing ONNX models. It is more robust and produces more optimized graphs than the current dynamo_export. dynamo_export will be updated to use the same logic starting from PyTorch 2.6 (tentative).

justinchuby avatar Oct 01 '24 17:10 justinchuby

In my experience, using dynamo=True has had huge loss in accuracy in downstream formats (onnx and trt) for detection models like DFINE.

q-prashant avatar Nov 11 '25 05:11 q-prashant

> In my experience, using dynamo=True has had huge loss in accuracy in downstream formats (onnx and trt) for detection models like DFINE.

Could you share example models?

justinchuby avatar Nov 11 '25 16:11 justinchuby

@justinchuby also following up from this issue: do we have a way to enable modules_as_functions in the dynamo ONNX export?

When I tried it this way, it doesn't encapsulate nn.Module as a function. (Is it completely deprecated in torch==2.9.0?)

onnx_program = torch.onnx.export(
    model,
    (tensor_input, dict_input, list_input),
    dynamic_shapes=(
        {0: "batch_size"},
        {"tensor_x": {0: "batch_size"}},
        [{0: "batch_size"}],
    ),
    input_names=input_names,
    output_names=output_names,
    dynamo=True,
    export_modules_as_functions=True,
    do_constant_folding=True,
)

vbaddi avatar Nov 12 '25 09:11 vbaddi

There isn’t an option. What’s the use case? Users are encouraged to write their own passes to do this using onnx-ir and metadata associated with the model.

justinchuby avatar Nov 12 '25 21:11 justinchuby

@justinchuby the issue was resolved once I upgraded to torch 2.9. So all good!

pra-dan avatar Nov 13 '25 04:11 pra-dan

> There isn’t an option. What’s the use case? Users are encouraged to write their own passes to do this using onnx-ir and metadata associated with the model.

@justinchuby Thanks! Is there documentation on how to write such passes, for example to enable functions and get a graph structure similar to what export_modules_as_functions={MODULES} produces today? The use case is to have repeated blocks share a single unique function proto signature, so that it's easier to optimize or speed up the compilation process.

vbaddi avatar Nov 13 '25 05:11 vbaddi

> There isn’t an option. What’s the use case? Users are encouraged to write their own passes to do this using onnx-ir and metadata associated with the model.

Hi @justinchuby, export_modules_as_functions is very convenient for writing ONNX local functions and loading them in TensorRT. A typical example is exporting the weightless part of an LLM's attention module as a parameterized function, and then reimplementing continuous batching and FlashInfer-based paged inference inside TensorRT plugins. Without export_modules_as_functions, implementing similar functionality (e.g. through an ONNX plugin) would be more cumbersome.

class LocalFunctionForAttn(torch.nn.Module):
    params: str

    def __init__(self):
        super().__init__()
        self.params = ""

    def add_params(self, key, value):
        # Accumulate "key=value" pairs separated by ";"
        if len(self.params) != 0:
            self.params += ";"
        self.params += f"{key}={value}"

    def forward(self, q, k, v):
        return q + k + v
https://github.com/torchpipe/torchpipe/blob/7c6901844527db1be59a0eaf5fb5c8e5927fe99a/plugins/torchpipe/examples/llama2/models/export_onnx_v2.py#L225

tp-nan avatar Nov 27 '25 02:11 tp-nan