How does this work with ONNX export and quantization?

ogencoglu opened this issue 1 year ago · 6 comments

Do quantized models here become quantized models in ONNX after conversion? Can you even convert/export them to ONNX? How about the other way around: can you export a sparse model to ONNX and quantize it in ONNX afterwards?

ogencoglu avatar Aug 29 '24 11:08 ogencoglu

We haven't really experimented much with ONNX so far. That said, we do support export, and once you export a model you can use an ONNX backend:

  1. Step 1: Export an AO model: https://github.com/pytorch/ao/tree/main/torchao/quantization#workaround-with-unwrap_tensor_subclass-for-export-aoti-and-torchcompile-pytorch-24-and-before-only
  2. Step 2: Use the ONNX backend: https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

If you want to work through an example and post your progress here, I'm happy to unblock you! We can add an example to the repo.
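A rough, untested sketch of those two steps, assuming int8 weight-only quantization and a placeholder toy model (the model, shapes, and file name are made up; the exact ONNX exporter entry point varies across PyTorch versions):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass

# Placeholder toy model just to illustrate the flow
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().eval()
quantize_(model, int8_weight_only())  # AO quantization

# Step 1: unwrap_tensor_subclass workaround (needed for torch <= 2.4),
# then capture the model with torch.export as in the linked doc
model = unwrap_tensor_subclass(model)
example_inputs = (torch.randn(1, 64),)
exported = torch.export.export(model, example_inputs)

# Step 2: hand the model to the ONNX exporter (torch.onnx.dynamo_export here,
# as in the linked tutorial; newer releases use torch.onnx.export(..., dynamo=True))
onnx_program = torch.onnx.dynamo_export(model, *example_inputs)
onnx_program.save("toy_int8.onnx")
```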

msaroufim avatar Aug 29 '24 15:08 msaroufim

Using the procedure outlined above, I successfully exported a simple quantized model. However, when applying the same approach to a custom, nested, large model, torch.export.export() throws the error shown below. To bypass the unwrap_tensor_subclass issue, I upgraded to Torch 2.5.1. Any guidance or suggestions to resolve this would be greatly appreciated.

```
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "<torch_path>/torch/export/__init__.py", line 270, in export
    return _export(
  File "<torch_path>/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "<torch_path>/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
  File "<torch_path>/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
  File "<torch_path>/torch/export/_trace.py", line 1880, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "<torch_path>/torch/export/_trace.py", line 1224, in _strict_export
    return _strict_export_lower_to_aten_ir(
  File "<torch_path>/torch/export/_trace.py", line 1333, in _strict_export_lower_to_aten_ir
    aten_export_artifact = lower_to_aten_callback(
  File "<torch_path>/torch/export/_trace.py", line 637, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
  File "<torch_path>/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "<torch_path>/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "<torch_path>/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "<torch_path>/torch/_functorch/aot_autograd.py", line 729, in _create_aot_dispatcher_function
    raise RuntimeError(
RuntimeError: aot_export is not currently supported with traceable tensor subclass.
If you need this feature, please comment on <CREATE_ISSUE_LINK>
```

hahnec avatar Dec 13 '24 21:12 hahnec

@hahnec did you use unwrap_tensor_subclass before calling export?

jerryzh168 avatar Dec 14 '24 05:12 jerryzh168

@jerryzh168 Yes, I used unwrap_tensor_subclass with version 2.4.x. However, I have since upgraded, so the above error is from version 2.5.1 without using unwrap_tensor_subclass.

EDIT: I should mention that my custom model converts successfully with torch.onnx.export() when no quantization is applied. When using unwrap_tensor_subclass(model) with v2.5.1, I receive:

```
Exception has occurred: AttributeError
'UnwrapTensorSubclass' object has no attribute 'rebuild_stack'
  File "<path>/onnx_test2.py", line 28, in <module>
    model = unwrap_tensor_subclass(model)
AttributeError: 'UnwrapTensorSubclass' object has no attribute 'rebuild_stack'
```

hahnec avatar Dec 14 '24 07:12 hahnec

@hahnec so what is the difference between quantizing before the conversion and quantizing the ONNX model afterwards?

ogencoglu avatar May 14 '25 17:05 ogencoglu


@hahnec can you do torch.export first, before doing the ONNX export?
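Roughly something like this untested sketch (the model and input shapes are placeholders, and the dynamo=True path needs a recent PyTorch, 2.5+):

```python
import torch

example_inputs = (torch.randn(1, 64),)  # placeholder input

# 1) capture the quantized model with torch.export first ...
exported = torch.export.export(model, example_inputs)

# 2) ... then run the ONNX exporter on the resulting module
onnx_program = torch.onnx.export(exported.module(), example_inputs, dynamo=True)
onnx_program.save("model_int8.onnx")
```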

jerryzh168 avatar May 14 '25 19:05 jerryzh168