Failed to lower QParam + DeQuant QDQ pair to StableHLO
🐛 Bug
When exporting a pt2e quantized model to StableHLO, I got this error:
error: 'mhlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<32x3x3x3xi8>'
As far as I can tell, the QDQ converter patch introduced in #5763 only handles the Quant + DeQuant QDQ pair, but not the QParam + DeQuant QDQ pair, so the HLO -> MHLO conversion fails.
As seen in the pt2e quantized model visualization, the Conv2d weight is a quantized param with i8 dtype, followed by a DeQuant OP, as sketched below.
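In other words, after convert_pt2e with fold_quantize=True the graph contains two different QDQ patterns. The following is a rough sketch (my own illustration, with approximate quantized_decomposed op signatures), not a literal graph dump:

# Activation path: Quant + DeQuant pair (handled by the #5763 converter).
x_q  = torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zp, qmin, qmax, torch.int8)
x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(x_q, scale, zp, qmin, qmax, torch.int8)

# Weight path: the Quant node is folded away, leaving a QParam + DeQuant pair
# (an i8 parameter fed directly into a DeQuant OP), which is what fails to lower.
w_q  = ...  # folded int8 weight parameter (QParam)
w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(w_q, scale, zp, qmin, qmax, torch.int8)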
To Reproduce
Here is the example code to reproduce the bug:
import os

import torch
import torch.export
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch_xla import stablehlo
import torch_xla.core.xla_model as xm


class TestModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv_0 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.conv_0(input)
        out = torch.sigmoid(x)
        return out


if __name__ == "__main__":
    device = xm.xla_device()

    model = TestModel().eval()
    input_tensor = torch.randn(size=(1, 3, 64, 64), dtype=torch.float32)
    sample_inputs = (input_tensor,)
    model = capture_pre_autograd_graph(model, sample_inputs)

    # Insert PTQ observers.
    quant_config = get_symmetric_quantization_config()
    quantizer = XNNPACKQuantizer().set_global(quant_config)
    model_prepared = prepare_pt2e(model, quantizer)

    # Do calibration here (omitted for the minimal repro).

    # Quantize the model.
    model_quant = convert_pt2e(model_prepared, fold_quantize=True)
    model_quant_exported = torch.export.export(model_quant, sample_inputs)

    # Visualize the quantized model.
    model_quant_exported.graph.print_tabular()
    drawer = FxGraphDrawer(model_quant_exported, "model_qdq")
    with open(f"{drawer._name}.svg", mode="wb") as f:
        f.write(drawer.get_dot_graph().create_svg())

    # Convert to StableHLO.
    print(f"[================ PID: {os.getpid()} ================]")
    model_stablehlo = stablehlo.exported_program_to_stablehlo(model_quant_exported)
    model_stablehlo_str = model_stablehlo.get_stablehlo_text()
    print(model_stablehlo_str)
The XLA log containing the lowered HLO:
Execution Analysis: ================================================================================
2024-02-19 16:44:53.339801: I torch_xla/csrc/runtime/pjrt_computation_client.cc:550] Executing PjRt computation on CPU:0
2024-02-19 16:44:53.339833: I external/xla/xla/pjrt/cpu/cpu_client.cc:1641] ExecuteShard executes computation SyncTensorsGraph.4 on assigned replica/partition on device TFRT_CPU_0
2024-02-19 16:44:53.340033: I torch_xla/csrc/runtime/pjrt_computation_client.cc:605] Returning 2 results
loc("custom-call.2"): error: 'mhlo.uniform_dequantize' op operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<32x3x3x3xi8>'
2024-02-19 16:45:01.140535: I torch_xla/csrc/runtime/tf_logging.cc:12] Check failed: status.ok()
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
...
__libc_start_main
_start
*** End stack trace ***
HLO -> MHLO conversion failed.
MHLO Module from HLO -> MHLO conversion is not legal.Please open a github issue to PyTorch/XLA.
Original HLO dump:
HloModule IrToHlo.11, entry_computation_layout={(s8[32,3,3,3]{3,2,1,0}, f32[1,3,64,64]{3,2,1,0})->(f32[1,32,64,64]{3,2,1,0})}
ENTRY %IrToHlo.11 (p0.1: s8[32,3,3,3], p1.3: f32[1,3,64,64]) -> (f32[1,32,64,64]) {
%p1.3 = f32[1,3,64,64]{3,2,1,0} parameter(1)
%custom-call.4 = s8[1,3,64,64]{3,2,1,0} custom-call(f32[1,3,64,64]{3,2,1,0} %p1.3), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
%custom-call.5 = f32[1,3,64,64]{3,2,1,0} custom-call(s8[1,3,64,64]{3,2,1,0} %custom-call.4), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
%p0.1 = s8[32,3,3,3]{3,2,1,0} parameter(0)
%custom-call.2 = f32[32,3,3,3]{3,2,1,0} custom-call(s8[32,3,3,3]{3,2,1,0} %p0.1), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-127,storage_max=127}
%convolution.6 = f32[1,32,64,64]{3,2,1,0} convolution(f32[1,3,64,64]{3,2,1,0} %custom-call.5, f32[32,3,3,3]{3,2,1,0} %custom-call.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01
%custom-call.7 = s8[1,32,64,64]{3,2,1,0} custom-call(f32[1,32,64,64]{3,2,1,0} %convolution.6), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
%custom-call.8 = f32[1,32,64,64]{3,2,1,0} custom-call(s8[1,32,64,64]{3,2,1,0} %custom-call.7), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
%logistic.9 = f32[1,32,64,64]{3,2,1,0} logistic(f32[1,32,64,64]{3,2,1,0} %custom-call.8)
ROOT %tuple.10 = (f32[1,32,64,64]{3,2,1,0}) tuple(f32[1,32,64,64]{3,2,1,0} %logistic.9)
}
Environment
- Reproducible on XLA backend [CPU]:
- pytorch version: 2.3.0.dev20240218+cpu
- torch_xla version: f4971a7
Thanks for raising this issue @Nullkooland. @lsy323 can you please have a look?
Hi @Nullkooland, thank you for reporting the issue!
Upstream introduced a BC-breaking change in which the fp->quant pair is folded by default. As you mentioned, QParam + DeQuant is not supported now, so you'll need to set fold_quantize=False. The repro script passed on my end after disabling the quant weight folding.
Also, please don't forget to set STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1 (as mentioned in https://github.com/pytorch/xla/pull/5763) until the related StableHLO issue is resolved.
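For reference, a minimal sketch of the workaround applied to the repro script above (setting the environment variable from Python is an assumption; exporting it in the shell before launching works as well):

# Set before the StableHLO export, per pytorch/xla#5763
# (a later comment in this thread notes it is no longer needed on newer torch_xla):
os.environ["STABLEHLO_BYTECODE_FROM_PRETTYPRINT"] = "1"

# Keep the weight Quant nodes in the graph instead of folding them into int8 params,
# so the converter only sees Quant + DeQuant pairs:
model_quant = convert_pt2e(model_prepared, fold_quantize=False)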
@lsy323 Thanks for your reply!
Will the QParam + DeQuant QDQ case be supported in the future, so that the exported StableHLO would look like this:
module @ExampleQDQModel {
  func.func @main(
      %weight_0_q: tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>,  // Folded QParams.
      %input: tensor<1x3x64x64xf32>
  ) {
    %input_q = stablehlo.uniform_quantize %input : (tensor<1x3x64x64xf32>) -> tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
    %input_dq = stablehlo.uniform_dequantize %input_q : (tensor<1x3x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x64x64xf32>
    %weight_0_dq = stablehlo.uniform_dequantize %weight_0_q : (tensor<32x3x3x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<32x3x3x3xf32>
    %conv_0 = stablehlo.convolution(%input_dq, %weight_0_dq) {...} : (tensor<1x3x64x64xf32>, tensor<32x3x3x3xf32>) -> tensor<1x32x64x64xf32>
    %conv_0_q = stablehlo.uniform_quantize %conv_0 : (tensor<1x32x64x64xf32>) -> tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>
    %conv_0_dq = stablehlo.uniform_dequantize %conv_0_q : (tensor<1x32x64x64x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x32x64x64xf32>
    ...
  }
}
This way the quantized params could be exported directly, reducing the size of the exported artifacts.
Hi @Nullkooland, just FYI: STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1 is not needed anymore.
For this question I'm not sure; I'll leave it to the StableHLO team members to provide some input. cc @sdasgup3
Will the QParam + DeQuant QDQ case be supported in the future, so that the exported StableHLO would look like this?
Yes, we have plans to achieve this outcome. There may be different paths to the goal, such as doing the pattern matching at the ATen level vs. the StableHLO level, which we are still exploring.