Error when using Resize layers in quantized QNN models
Describe the issue
When using PyTorch's nn.Upsample, the exported ONNX model works fine, but when I quantize the ONNX model for QNN I get the following error:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Resize node. Name:'/upsample/Resize' Status Message: upsamplebase.h:439 onnxruntime::UpsampleBase::ScalesValidation 'Linear' mode only supports:
* 2-D inputs or
* 3-D inputs ('Bilinear', 'Trilinear') or
* 4-D inputs with the corresponding outermost 2 scale values being 1 or the corresponding outermost and innermost scale values being 1 or
* 5-D inputs with the corresponding outermost 2 scale values being 1in the Resize operator
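Reading the error message, for a 4-D NCHW input in 'Linear' mode the two outermost scale values (or the outermost and innermost ones) must be exactly 1. A minimal sketch of that condition as I understand it (the names are mine, not the actual onnxruntime source):

# Sketch of the 4-D 'Linear'-mode condition as I read the error message;
# names are mine, not the actual onnxruntime source.
scales = [1.0, 1.0, 2.0, 2.0]  # NCHW scales produced by nn.Upsample(scale_factor=2)
valid_4d = (scales[0] == 1.0 and scales[1] == 1.0) or (scales[0] == 1.0 and scales[3] == 1.0)
assert valid_4d  # holds for the float model; any perturbation of the leading 1.0s breaks it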
To reproduce
- install
onnxruntime==1.22
torch==2.7.1
onnx==1.18.0
- Run
"""Minimal example of to reproduce the issue with nn.Upsample in qnn quantization."""
import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import quantize, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model
import torch
import torch.nn as nn
# Define a simple model with an upsample layer
class UpsampleModel(nn.Module):
    """A simple model that contains an upsample layer."""

    def __init__(self):
        super(UpsampleModel, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")

    def forward(self, x):
        return self.upsample(x)


class DummyCalibrationDataReader:
    """A dummy calibration data reader that generates random data for testing."""

    def __init__(self):
        # Create dummy calibration data
        self.data = [np.random.rand(1, 3, 4, 4).astype(np.float32) for _ in range(10)]
        self.index = 0

    def get_next(self):
        if self.index < len(self.data):
            calibration_data = {"input": self.data[self.index]}
            self.index += 1
            return calibration_data
        return None

    def rewind(self):
        self.index = 0
def main():
    # Export the model to ONNX
    model = UpsampleModel()
    input_tensor = torch.randn(1, 3, 4, 4)  # Example input tensor
    onnx_file_path = "upsample_model.onnx"
    torch.onnx.export(
        model,
        input_tensor,
        onnx_file_path,
        opset_version=20,
        input_names=["input"],    # Explicitly specify input name
        output_names=["output"],  # (optional) specify output name
    )

    # Load the ONNX model
    onnx_model = onnx.load(onnx_file_path)
    onnx.checker.check_model(onnx_model)

    # Test the ONNX model
    ort_session = onnxruntime.InferenceSession(onnx_file_path)
    ort_inputs = {ort_session.get_inputs()[0].name: input_tensor.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # Verify the output
    model_output = model(input_tensor).detach().numpy()
    assert np.allclose(model_output, ort_outs[0], atol=1e-6), "ONNX output does not match PyTorch output"
    print("Upsample layer test passed!")

    # Preprocess the model for quantization
    preprocessed_model_path = "preprocessed_upsample_model.onnx"
    changed = qnn_preprocess_model(onnx_file_path, preprocessed_model_path)
    model_to_quantize = preprocessed_model_path if changed else onnx_file_path

    # Instantiate the dummy calibration data reader
    calibration_data_reader = DummyCalibrationDataReader()

    # Get QNN QDQ configuration
    qnn_config = get_qnn_qdq_config(
        model_to_quantize,
        calibration_data_reader=calibration_data_reader,  # Replace with an actual calibration data reader if available
        activation_type=QuantType.QUInt16,
        weight_type=QuantType.QUInt16,
    )

    # Quantize the model
    quantized_model_path = "quantized_upsample_model.onnx"
    quantize(model_to_quantize, quantized_model_path, qnn_config)
    print("Quantization completed. Quantized model saved at:", quantized_model_path)

    # Run the inference on the quantized model
    ort_session_quantized = onnxruntime.InferenceSession(quantized_model_path)
    ort_inputs_quantized = {ort_session_quantized.get_inputs()[0].name: input_tensor.numpy()}
    ort_outs_quantized = ort_session_quantized.run(None, ort_inputs_quantized)
    # The run above fails with:
    #
    # 2025-06-18 18:25:30.8555744 [E:onnxruntime:, sequential_executor.cc:516 onnxruntime::ExecuteKernel]
    # Non-zero status code returned while running Resize node. Name:'/upsample/Resize'
    # Status Message: upsamplebase.h:439 onnxruntime::UpsampleBase::ScalesValidation
    # 'Linear' mode only supports:
    # * 2-D inputs or
    # * 3-D inputs ('Bilinear', 'Trilinear') or
    # * 4-D inputs with the corresponding outermost 2 scale values being 1 or the corresponding
    #   outermost and innermost scale values being 1 or
    # * 5-D inputs with the corresponding outermost 2 scale values being 1in the Resize operator

    # Verify the output of the quantized model
    assert np.allclose(
        model_output, ort_outs_quantized[0], atol=1e-6
    ), "Quantized ONNX output does not match PyTorch output"


if __name__ == "__main__":
    main()
Urgency
No response
Platform
Windows
OS Version
Windows 11
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.22
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
When inspecting the quantized model with Netron, I see that QuantizeLinear/DequantizeLinear nodes have been inserted around the Resize node's scales input.
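The same thing can be checked without Netron; a small sketch, reusing the quantized model path from the script above, that lists the Resize node's inputs and any Q/DQ nodes:

import onnx

# Print the Resize node's inputs and all Q/DQ nodes in the quantized model
# (path taken from the reproduction script above).
m = onnx.load("quantized_upsample_model.onnx")
for node in m.graph.node:
    if node.op_type == "Resize":
        print("Resize", node.name, "inputs:", list(node.input))
    elif node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        print(node.op_type, node.input[0], "->", node.output[0])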
It is a bit surprising to me that the scales input gets quantized at all. Maybe the error is due to the two leading ones no longer being exactly equal to one after quantization and dequantization.
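A quick round-trip check is consistent with that guess. This is only a sketch: it assumes asymmetric uint16 quantization with the calibration range extended to include zero, which may not match the exact parameters the quantizer picks:

import numpy as np

# Round-trip the Resize scales [1, 1, 2, 2] through a uint16 quantize/dequantize
# pair. Assumes asymmetric quantization with the range extended to include 0.
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
rmin, rmax = min(scales.min(), 0.0), max(scales.max(), 0.0)
qmin, qmax = 0, 65535
step = (rmax - rmin) / (qmax - qmin)         # quantization step, ~3.05e-05
zero_point = int(round(qmin - rmin / step))  # 0 here
q = np.clip(np.round(scales / step) + zero_point, qmin, qmax)
dq = (q - zero_point) * step
print(dq)  # ~[1.0000153 1.0000153 2.        2.       ]

The leading values come back as roughly 1.0000153 rather than exactly 1.0, which would trip the exact-equality check in ScalesValidation.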
Using self.upsample = nn.Upsample(size=[8, 8], mode="bilinear") instead solved the problem; the graph I then obtain feeds the Resize node an explicit sizes input instead of a scales tensor.
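For completeness, a sketch of the working variant (the class name is mine); it is a drop-in replacement for UpsampleModel in the script above:

import torch.nn as nn

class UpsampleModelFixed(nn.Module):
    """Workaround variant of UpsampleModel (class name is mine)."""

    def __init__(self):
        super().__init__()
        # An explicit output size avoids the fractional scales tensor entirely.
        self.upsample = nn.Upsample(size=[8, 8], mode="bilinear")

    def forward(self, x):
        return self.upsample(x)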
Applying stale label due to no activity in 30 days
Closing issue due to no activity in 30 days