[OnnxToTorch] Fix Resize op when ONNX exports dynamic spatial dims as 0
ONNX encodes dynamic spatial dimensions in Resize as 0. The current conversion forwards these spatial dimensions unchanged to later steps, but [0, 0] is not a valid proposed size for the interpolate/resize op. This change replaces each such 0 with the corresponding runtime dimension of the input tensor, so that shapes propagate correctly in Torch-MLIR and no invalid 0-sized dimensions are produced.
This fixes the issue: https://github.com/iree-org/iree/issues/19501
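To make the intended behaviour concrete, here is a minimal Python sketch of the substitution described above. It is not the C++ change itself: the helper name `resolve_zero_sizes` and its list-based arguments are hypothetical and only illustrate the rule "a 0 in the proposed sizes falls back to the input's runtime dimension".

```python
def resolve_zero_sizes(proposed_sizes, input_spatial_dims):
    """Replace each 0 in proposed_sizes with the matching input dimension.

    Hypothetical helper for illustration; the real change operates on
    Torch-MLIR values, not Python lists.
    """
    if len(proposed_sizes) != len(input_spatial_dims):
        raise ValueError("sizes and input spatial dims must have the same rank")
    return [
        in_dim if size == 0 else size
        for size, in_dim in zip(proposed_sizes, input_spatial_dims)
    ]


# Runtime input of shape (N, C, 32, 54): the spatial dims are (32, 54).
print(resolve_zero_sizes([0, 0], [32, 54]))   # [32, 54]
print(resolve_zero_sizes([16, 0], [32, 54]))  # [16, 54]
```

Non-zero entries are left untouched, so a model that supplies explicit sizes would be unaffected by this rule.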
Example ONNX model export:
>>> import torch
>>> import torch.nn as nn
>>> import torch.onnx
>>>
>>> class InterpolateModel(nn.Module):
...     def forward(self, x):
...         return nn.functional.interpolate(x, scale_factor=0.5, mode='bicubic', align_corners=False)
...
>>> model = InterpolateModel()
>>> dummy_input = torch.randn(11, 3, 32, 54)
>>>
>>> torch.onnx.export(
...     model,
...     dummy_input,
...     "interpolate_dynamic_cubic.onnx",
...     opset_version=11,
...     input_names=['pixel_values'],
...     output_names=['output'],
...     dynamic_axes={
...         'pixel_values': {0: 'batch', 2: 'height', 3: 'width'},
...         'output': {0: 'batch', 2: 'height', 3: 'width'},
...     }
... )
After exporting the model, reading dim_value for each input dimension gives 0 for every dynamic axis:
>>> import onnx
>>> print("Exported interpolate_dynamic_cubic.onnx")
Exported interpolate_dynamic_cubic.onnx
>>> dynamic_input_model = onnx.load("interpolate_dynamic_cubic.onnx")
>>> for input_tensor in dynamic_input_model.graph.input:
...     shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
...     print(f"{input_tensor.name}: {shape}")
...
pixel_values: [0, 3, 0, 0]
In IREE, one of the tests that uses a Hugging Face model is failing because the interpolate/resize sizes are [0, 0]. The sizes are computed as follows:
- Input (pixel_values)
- Get Shape
- Gather indices 2 and 3 (the spatial dimensions, height and width)
- Divide both by 16 (patch size)
- Concat them.
But the dimensions of the input pixel_values are all reported as [0, 0, 0, 0]:
>>> import onnx
>>> model = onnx.load("model.onnx")
>>> for input_tensor in model.graph.input:
...     shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
...     print(f"{input_tensor.name}: {shape}")
...
pixel_values: [0, 0, 0, 0]
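The Shape, Gather, Div, Concat chain listed above can be sketched in plain Python to show what it would yield for a real runtime shape versus the all-zero placeholder. This is only an illustration under assumptions: the function name `resize_sizes_from_shape` and the 224x224 sample shape are made up, and integer floor division is assumed for the Div step.

```python
def resize_sizes_from_shape(shape, patch_size=16):
    """Mimic Shape -> Gather(2, 3) -> Div(patch_size) -> Concat on an NCHW shape."""
    height, width = shape[2], shape[3]  # Gather indices 2 and 3
    # Div by the patch size (floor division assumed), then Concat.
    return [height // patch_size, width // patch_size]


# A concrete runtime shape gives a usable patch grid.
print(resize_sizes_from_shape((1, 3, 224, 224)))  # [14, 14]

# Feeding the placeholder zeros through the same chain gives [0, 0].
print(resize_sizes_from_shape((0, 0, 0, 0)))      # [0, 0]
```

If the graph really ran this chain on the runtime tensor, the first case is what Resize would receive; [0, 0] only appears if zeros are baked in somewhere upstream.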
@HalfBloodPrince010 I'm not sure I understand this change.
The example you gave in the PR comment generates IR like:
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.7.1"} {
%none = torch.constant.none
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<4xf32>} : () -> !torch.vtensor<[4],f32>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<0xf32>} : () -> !torch.vtensor<[0],f32>
%2 = torch.operator "onnx.Resize"(%arg0, %1, %0) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "cubic", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[?,3,?,?],f32>, !torch.vtensor<[0],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[?,3,?,?],f32>
return %2 : !torch.vtensor<[?,3,?,?],f32>
}
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000000803F0000803F0000003F0000003F",
__1: "0x08000000"
}
}
#-}
Which converts to:
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.7.1"} {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%none = torch.constant.none
%false = torch.constant.bool false
%str = torch.constant.str "cubic"
%0 = torch.prim.ListConstruct %float5.000000e-01, %float5.000000e-01 : (!torch.float, !torch.float) -> !torch.list<float>
%1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[?,3,?,?],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,3,?,?],f32>
return %1 : !torch.vtensor<[?,3,?,?],f32>
}
}
Which uses the scales and not the sizes (as it should, since the sizes arg was not provided in the original model).
Calling dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim returns 0 for those dims because the field is unset, and 0 is how protobuf reports an "empty" integer field. dim.dim_param, on the other hand, returns something like "batch", etc., which gets imported correctly as a dynamic dim.
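A small sketch of how the two fields could be read together. To keep it runnable without an ONNX file, it uses `SimpleNamespace` stand-ins that only mimic the `dim_value` / `dim_param` attributes; `describe_dim` is a hypothetical helper, not part of the onnx package.

```python
from types import SimpleNamespace


def describe_dim(dim):
    """Return the symbolic name if present, else the static size, else '?'."""
    if dim.dim_param:      # named dynamic dimension, e.g. "batch"
        return dim.dim_param
    if dim.dim_value > 0:  # static dimension
        return dim.dim_value
    return "?"             # neither field carries information


# Stand-ins for the dims of pixel_values when exported with dynamic_axes.
dims = [
    SimpleNamespace(dim_value=0, dim_param="batch"),
    SimpleNamespace(dim_value=3, dim_param=""),
    SimpleNamespace(dim_value=0, dim_param="height"),
    SimpleNamespace(dim_value=0, dim_param="width"),
]

print([d.dim_value for d in dims])      # [0, 3, 0, 0]
print([describe_dim(d) for d in dims])  # ['batch', 3, 'height', 'width']
```

The same function should work on real `shape.dim` entries, since it only touches those two attributes. On real protobuf objects, `dim.WhichOneof("value")` is a stricter way to tell which field is set.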
Hi @zjgarvey, thank you for the response.
The example I gave was meant to illustrate how ONNX sometimes exports dynamic dims as 0. When I visualized the exported ONNX model produced by the test case from https://github.com/iree-org/iree/issues/19501, I observed the following:
input pixel_values [0, 0, 0, 0]
→ gather indices 2 & 3 (spatial dims)
→ divide by 16 (to match the patch grid)
→ concat
This serves as the sizes operand for the Resize op. Below is the IR from the issue linked above:
%304 = torch.operator "onnx.Resize"(%138, %none, %none, %303)
%303 = torch.operator "onnx.Concat"(%139, %302) # Sizes for the Resize Op
%302 = torch.operator "onnx.Cast"(%287)
%287 = torch.operator "onnx.Concat"(%283, %285)
and
%283 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__27> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%285 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__28> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
__27: "0x080000000000000000000000",
__28: "0x080000000000000000000000",
I now understand from your comment that in ONNX protobuf, dim.dim_value=0 is just a placeholder, and dim.dim_param carries the true dynamic information. In that case, is there a way to ensure that %303, which holds the sizes, does not end up as [0, 0]?
I'd need a bit more context to help pin down the real crux of the model issue, but if the onnx model is doing 'onnx.Shape' on a dynamic tensor, we should definitely not be getting zeros.
The fact that there are literal constant zeros being passed to Resize seems like an export bug, unless I'm misunderstanding the 'onnx.Resize' op functionality.
Understood. You are saying that if the spatial dimensions of the proposed sizes in the Resize op here, https://github.com/HalfBloodPrince010/torch-mlir/blob/main/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp#L3138-L3167, are [0, 0], then that itself is a bug.
But, based on the above IR, the proposedSizes (i.e. %303) are constructed from the input pixel_values:
input pixel_values [0, 0, 0, 0]
→ gather indices 2 & 3 (spatial dims)
→ divide by 16 (to match the patch grid)
→ concat
So, if the input spatial dims are 0, is there a way to enforce that the Resize op's proposed sizes are not 0?
More context is in https://discord.com/channels/689900678990135345/1398431144210333818. Let me know if I have missed anything.
What I'm saying is that the IR literally concatenates zeros for the proposed sizes. It is evidently not getting them from a 'onnx.Shape' op + gathers + divs or anything like that.
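One way to check this from the IR alone is to decode the `dense_resource` blobs quoted earlier in the thread. The sketch below assumes the hex string is a 4-byte little-endian alignment prefix followed by the raw little-endian tensor data; that layout is an assumption about how MLIR prints resource blobs, and `decode_blob` is a hypothetical helper.

```python
import struct


def decode_blob(hex_blob, fmt):
    """Decode a resource blob, assuming a 4-byte alignment prefix.

    fmt is a single struct format character for the element type,
    e.g. "q" for si64 or "f" for f32.
    """
    digits = hex_blob[2:] if hex_blob.startswith("0x") else hex_blob
    raw = bytes.fromhex(digits)
    alignment = struct.unpack_from("<I", raw, 0)[0]
    payload = raw[4:]
    count = len(payload) // struct.calcsize("<" + fmt)
    values = struct.unpack("<" + fmt * count, payload)
    return alignment, list(values)


# __27 / __28 from the issue IR (tensor<1xsi64>): a single literal 0.
print(decode_blob("0x080000000000000000000000", "q"))  # (8, [0])

# The scales constant from the small reproducer (tensor<4xf32>).
print(decode_blob("0x080000000000803F0000803F0000003F0000003F", "f"))
# (8, [1.0, 1.0, 0.5, 0.5])
```

Under that assumption, %283 and %285 are constant zeros in the imported IR, while the scales in the small reproducer decode to the expected 0.5 factors.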
Thanks @zjgarvey. Any pointers on where this shape propagation or the export could be going wrong, or on which passes would be useful to trace further, would be super helpful.
I thought we got the 0s because ONNX exports dynamic axes as 0, and that these then got propagated.
I'm not exactly sure. I'd look at exporting a submodule for the problematic IR to get a smaller reproducer to start. Look at the model code in pytorch and the onnx graph torch exports. If that looks bad, it's likely an issue in torch or the model code.
If the torch exported onnx model looks good, but the mlir is bad, having a smaller e2e torch->onnx->torch-mlir reproducer to address will be helpful, since the original IR posted in the issue isn't enough to debug the full picture.
Some other questions: are we applying onnxruntime optimizations before export in this particular model? If something is going wrong there, that might be an indicator.
Are the sample model inputs a reasonable shape? If the sizes end up being <16 for the sample input, is something folding dim//16 to zero?
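As a quick illustration of that last question: if the export traced a sample input whose spatial dims are smaller than the patch size and the division got constant-folded, the result would be a literal 0. The shapes below are made up, and `spatial_sizes` is a hypothetical helper that assumes floor division by 16.

```python
def spatial_sizes(shape, patch_size=16):
    """Floor-divide the NCHW spatial dims (indices 2 and 3) by patch_size."""
    return [dim // patch_size for dim in shape[2:4]]


# Any spatial dim below the patch size folds to 0.
for shape in [(1, 3, 8, 8), (1, 3, 15, 32), (1, 3, 16, 16)]:
    print(shape, "->", spatial_sizes(shape))
# (1, 3, 8, 8) -> [0, 0]
# (1, 3, 15, 32) -> [0, 2]
# (1, 3, 16, 16) -> [1, 1]
```

Checking the sample input shape used at export time against the patch size would rule this case in or out.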