onnx-mlir
ONNX model MLIR conversion failing
ONNX model to MLIR conversion is failing after a successful ONNX export.
onnx_model_path = "/vlaWithDynamo.onnx"
torch.onnx.export(
    vla,
    (inputs["input_ids"], inputs["attention_mask"], inputs["pixel_values"]),
    onnx_model_path,
    input_names=["input_node"],
    output_names=["output_node"],
    dynamo=True,
    report=True,
    export_params=True,
)
./onnx-mlir --EmitMLIR /vlaWithDynamo.onnx
[1/3] Wed Apr 16 16:00:24 2025 (0s) Importing ONNX Model to MLIR Module from "vlaWithDynamo.onnx"
[2/3] Wed Apr 16 16:00:24 2025 (0s) Compiling and Optimizing MLIR Module
'onnx.Pow' op Pow with different input type not implemented yet
System Info:
OS: Linux Ubuntu 22.04
ONNX version: 1.17.0
Python version: 3.10.12
Protobuf version: 6.30.2
torch: 2.8.0.dev20250415+cpu
onnx-mlir version: 0.4.2
Code:
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import numpy as np
import torch
import time
import onnx
from onnxruntime.training import artifacts
from pathlib import Path
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    # attn_implementation="flash_attention_2",  # [Optional] Requires flash_attn
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
onnx_model_path = "/vlaWithDynamo.onnx"
torch.onnx.export(
    vla,
    (inputs["input_ids"], inputs["attention_mask"], inputs["pixel_values"]),
    onnx_model_path,
    input_names=["input_node"],
    output_names=["output_node"],
    dynamo=True,
    report=True,
    export_params=True,
)
Expected behavior: the ONNX model should be converted to its MLIR form without any issues.
Hi @kraza8, as the error message says, Pow currently does not support different types for the base and the power. It only supports the case where both base and power are float. We need help from someone to add support for the case of different types.
So I can get onnx-mlir to handle torch.pow(x, y) with x and y of mixed types: the exporter just casts one of the inputs appropriately before emitting onnx.Pow. You get something like
func.func @forward(%arg0: tensor<64x128xi64> {onnx.name = "input_0"}, %arg1: tensor<64x128xf32> {onnx.name = "input_1"}) -> (tensor<64x128xf32> {onnx.name = "3"}) attributes {llvm.emit_c_interface} {
%0 = "onnx.Cast"(%arg0) {onnx_node_name = "//Cast", saturate = 1 : si64, to = f32} : (tensor<64x128xi64>) -> tensor<64x128xf32>
%1 = "onnx.Pow"(%arg1, %0) {onnx_node_name = "//Pow"} : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
So the question is why this is not happening when compiling this model. But I think that's more of a torch.onnx.export issue.
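For reference, here is a minimal sketch of the kind of export that produces the MLIR above. The module, shapes, and names are hypothetical (not the OpenVLA model), chosen only to mirror the example:

import torch

class MixedPow(torch.nn.Module):
    def forward(self, exponent_i64, base_f32):
        # Mixed-type pow: float32 base, int64 exponent.
        return torch.pow(base_f32, exponent_i64)

exponent = torch.randint(0, 4, (64, 128), dtype=torch.int64)
base = torch.rand(64, 128, dtype=torch.float32)

# With the legacy exporter (dynamo=False) the exported graph contains the
# Cast shown above before onnx.Pow; with dynamo=True it may not.
torch.onnx.export(
    MixedPow(),
    (exponent, base),
    "mixed_pow.onnx",
    input_names=["input_0", "input_1"],
    output_names=["output"],
    dynamo=False,
)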
This could potentially be easy to resolve, but reading the spec, I have no idea how different types would be handled. Maybe converting the smaller one into the larger one?
@srcarroll I find that digging into the ONNX Runtime code is sometimes helpful to see how a well-accepted ONNX consumer does it. Potentially, one could also ask on the ONNX Slack channel what the accepted practice is.
So when I say torch.pow(x, y) compiles fine, that's with torch.onnx.export and dynamo=False. So by the time we want to convert to onnx-mlir, the graph is already doing the appropriate casting, as I showed in the example above. It's the dynamo=True case that produces this failure. So again, I think this is more a question of exporter behavior with different dynamo settings.
I find that dynamo=True almost never works. But then when things don't work for dynamo=False, the torch people just say use dynamo=True, which often just causes more problems. I don't think this is the concern of onnx-mlir though.
This could potentially be easy to resolve, but reading the spec, I have no idea how different types would be handled. Maybe converting the smaller one into the larger one?
It seems clear from the spec:
https://onnx.ai/onnx/operators/onnx__Pow.html
Inputs:
X (heterogeneous) - T: First operand, base of the exponent.
Y (heterogeneous) - T1: Second operand, power of the exponent.
Outputs:
Z (heterogeneous) - T: Output tensor.
So when I say torch.pow(x, y) compiles fine, that's with torch.onnx.export and dynamo=False. So by the time we want to convert to onnx-mlir, the graph is already doing the appropriate casting, as I showed in the example above. It's the dynamo=True case that produces this failure. So again, I think this is more a question of exporter behavior with different dynamo settings.
I find that dynamo=True almost never works. But then when things don't work for dynamo=False, the torch people just say use dynamo=True, which often just causes more problems. I don't think this is the concern of onnx-mlir though.
To be clear, the usage of Pow in the PyTorch exporter conforms to the ONNX spec as linked above. It is true, though, that some backends do not implement support for the case where the input types are different.
This could potentially be easy to resolve, but reading the spec, I have no idea how different types would be handled. Maybe converting the smaller one into the larger one?
It seems clear from the spec:
https://onnx.ai/onnx/operators/onnx__Pow.html
Inputs:
X (heterogeneous) - T: First operand, base of the exponent.
Y (heterogeneous) - T1: Second operand, power of the exponent.
Outputs:
Z (heterogeneous) - T: Output tensor.
I don't see how the rules of type casting are clear from this; all it says is what the operands and the result are.
I'm sure there's a general rule for casting that isn't specific to this op. I don't know what that is, but what @AlexandreEichenberger mentioned makes the most sense to me. But just saying "it's clear from the spec" and providing no explanation of how you get from this spec to casting rules isn't very helpful.
I don't know what that is, but what @AlexandreEichenberger mentioned makes the most sense to me
I should say it makes the most sense for simple cases, like when both are floating point or both are int. I think it's obvious that if one operand is a float, the other must be cast to some float type. But what if the float type is smaller than the int type? Do you make the float type large enough to not destroy the value of the int type? Or do you just ignore that, cast to the original float type, and leave it up to the user to be careful about that in the first place?
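For example, a small numpy check just to illustrate the concern about casting a larger int type into a smaller float type:

import numpy as np

# A large int64 value does not survive a round trip through float32:
# float32 has a 24-bit mantissa, so 2**31 + 1 rounds to 2**31.
big = np.int64(2**31 + 1)
as_f32 = np.float32(big)
print(big)                      # 2147483649
print(int(as_f32))              # 2147483648
print(np.int64(as_f32) == big)  # False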
Thanks for pointing this out. Reading it again, I see that the spec doesn't say what dtype the computation should be performed in; but that's true for most other ops in the spec, although it's less ambiguous there because it is usually "the same type as the inputs."
IIRC ORT casts T1 to T.
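A quick way to check this (a sketch; it assumes the installed onnxruntime build implements mixed-type Pow) is to build a tiny graph with a float32 base and an int64 power and look at the output dtype:

import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

# Pow with a float32 base (T) and an int64 power (T1), as the spec allows.
node = helper.make_node("Pow", ["X", "Y"], ["Z"])
graph = helper.make_graph(
    [node],
    "mixed_pow_check",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3]),
     helper.make_tensor_value_info("Y", TensorProto.INT64, [3])],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 15)])
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CPUExecutionProvider"])
(z,) = sess.run(None, {"X": np.array([2.0, 3.0, 4.0], dtype=np.float32),
                       "Y": np.array([2, 3, 2], dtype=np.int64)})
print(z, z.dtype)  # output stays float32, i.e. the type of X (T)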
I think we may instead update the spec to constrain the input types to be the same, similar to https://openxla.org/stablehlo/spec#power
@justinchuby I did not know that you were "reading" our issues; big thanks for your input, super valuable.
Agreed on updating the spec, either restricting the types or being clearer.
Probably because of backward compatibility, we (i.e., onnx-mlir, ORT) have to come up with a policy for what to do with different types, so what ORT does makes sense: casting the power to the input/output type.
@srcarroll do you want to implement this, or should I? I think it's as simple as adding a create.math.cast. You can even add it unconditionally, as same-type casts are detected and treated as a no-op by the MathBuilder.
The only place where it might be a bit touchy is that we have a pattern that recognizes pow(x, n), where n is a small integer value (or rounds to a small integer value), and substitutes it with mult operators, e.g. pow(x, 2.0) -> mul(x, x). This part might be sensitive to the types of the two inputs. But this distinct-type situation might be a rare event, so it might not warrant worrying about it.
Improved in https://github.com/microsoft/onnxscript/pull/2335
Thanks @justinchuby. Just curious, why only for the tensor_scalar case? Is it not identical logic for tensor_tensor?
@srcarroll do you want to implement this, or should I? I think it's as simple as adding a create.math.cast. You can even add it unconditionally, as same-type casts are detected and treated as a no-op by the MathBuilder.
@AlexandreEichenberger sure, I can do that. But would the onnxscript changes suffice, or should onnx-mlir also handle it?
I added a todo in the PR for tensor_tensor. We don't want to do manual type promotion and complicate the implementation unless absolutely necessary (it will be fixed with a more general mechanism). From practical experience, we don't see tensor_tensor being used.
PR #3175 updates Pow to support type conversions.
Cool, thanks @AlexandreEichenberger. I'll post this follow-up question in your PR too, so feel free to answer wherever you think appropriate.
I see that you are now allowing mixed types in the onnx.Pow definition and handling the appropriate type cast in the to-kernel conversion. But I wonder if it would be more appropriate to keep the type constraints as they are, and instead of handling this in the to-kernel conversion, handle it in the ONNX to ONNX-dialect conversion done by the onnx-mlir tool.
I don't know what's best from a design standpoint, but I think that other conversions from the ONNX MLIR dialect will still have to handle the mixed types, for example the to-stablehlo and to-tosa conversions. So I guess my question is: should this be the responsibility of the specific backend consuming the ONNX dialect, or should it be inherent to the ONNX to ONNX-dialect conversion from the start?