[Quant] Can quant not be decomposed on inductor?
torch.ops.torchao.dequantize_affine is decomposed into convert_element_type and mul. Inductor runs constant folding before pattern matching, and during constant_fold Inductor replaces the fp8 weight and some of the preceding operations with an fp32 weight. Is this expected?
The decomposition is currently registered via register_decomposition.
This sample test can reproduce the issue:
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "0"
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "0"
from typing import Callable, List, Optional, Union
import torch
from torch import nn
import torchao
#import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq

def dequantize_per_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype
) -> torch.Tensor:
    res = torch.ops.torchao.dequantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        input_dtype=torch.float8_e4m3fn,
    )
    if output_dtype != torch.float:
        res = res.to(output_dtype)
    return res

def quantize_per_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    return torch.ops.torchao.quantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        output_dtype=torch.float8_e4m3fn,
    )

class Perceptron(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
        bias: bool = True,
        activation: Union[
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self._out_size = out_size
        self._in_size = in_size
        self._linear: nn.Linear = nn.Linear(
            self._in_size,
            self._out_size,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._activation_fn(self._linear(input))

class MLP(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        layer_sizes: List[int],
        bias: bool = True,
        activation: Union[
            str,
            Callable[[], torch.nn.Module],
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        if activation == "relu":
            activation = torch.relu
        elif activation == "sigmoid":
            activation = torch.sigmoid
        if not isinstance(activation, str):
            self._mlp: torch.nn.Module = torch.nn.Sequential(
                *[
                    Perceptron(
                        layer_sizes[i - 1] if i > 0 else in_size,
                        layer_sizes[i],
                        bias=bias,
                        activation=activation,
                        device=device,
                        dtype=dtype,
                    )
                    for i in range(len(layer_sizes))
                ]
            )
        else:
            raise ValueError(
                "This MLP only supports str activation functions of relu, sigmoid, and swish_layernorm"
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._mlp(input)

class DenseArch(nn.Module):
    def __init__(
        self,
        in_features: int,
        layer_sizes: List[int],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.model: nn.Module = MLP(
            in_features, layer_sizes, bias=True, activation="relu", device=device
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.model(features)

def inc_convert(model, dtype):
    model.eval()
    qtype = torch.float8_e4m3fn
    # from torch.ao.quantization.fx._decomposed import quantize_per_tensor, dequantize_per_tensor
    from torch.nn import functional as F

    class FP8QDQLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = torch.empty((out_features, in_features))
            self.weight_scale = None
            self.scale = None
            self.bias = None

        def forward(self, input):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )
            q_input = quantize_per_tensor(
                input,
                self.scale,
            )
            dq_input = dequantize_per_tensor(
                q_input,
                self.scale,
                dtype,
            )
            # out1 = torch._scaled_mm(q_input, self.weight.T, torch.tensor(self.scale), torch.tensor(self.weight_scale), bias=self.bias, out_dtype=torch.float8_e4m3fn)
            # out2 = torch.mm(dq_input, weight.T) + self.bias
            # out3 = torch.nn.functional.linear(dq_input, weight, self.bias)
            out = torch.nn.functional.linear(dq_input, weight, self.bias)
            return out

    class FP8QDQEmbeddingBag(torch.nn.Module):
        def __init__(self, weight_shape, max_norm, norm_type, scale_grad_by_freq, mode, sparse,
                     include_last_offset, padding_idx):
            super().__init__()
            # self.mod = mod
            self.max_norm = max_norm
            self.norm_type = norm_type
            self.scale_grad_by_freq = scale_grad_by_freq
            self.mode = mode
            self.sparse = sparse
            self.include_last_offset = include_last_offset
            self.padding_idx = padding_idx
            self.weight = torch.empty(weight_shape)
            self.weight_scale = None

        def forward(
            self,
            input,
            offsets=None,
            per_sample_weights=None,
        ):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )
            return F.embedding_bag(
                input,
                weight,
                offsets,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.mode,
                self.sparse,
                per_sample_weights,
                self.include_last_offset,
                self.padding_idx,
            )

    hook_handles = []
    import json
    from collections import namedtuple

    def generate_model_info(model):
        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
        parent_child_mod_dict = {}

        def create_mod_info_recursion(parent):
            for name, mod in parent.named_children():
                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
                create_mod_info_recursion(mod)

        create_mod_info_recursion(model)
        return parent_child_mod_dict

    parent_child_mod_dict = generate_model_info(model)
    with torch.no_grad():
        for i, (name, mod) in enumerate(model.named_modules()):
            mod_type_str = mod.__class__.__name__
            # print(mod_type_str)
            # continue
            if mod_type_str not in ["Linear", "EmbeddingBag"]:
                continue
            print(mod_type_str, name)
            param = mod.weight
            xmax = torch.max(param)
            weight_scale = xmax / torch.finfo(qtype).max
            setattr(mod, "weight_scale", weight_scale)
            q_param = torch.clamp((param / weight_scale), torch.finfo(qtype).min, torch.finfo(qtype).max).to(qtype)
            mod.weight.data = q_param
            if mod_type_str in ["Linear"]:
                scale = [1 / torch.finfo(qtype).max]
                assert len(scale) == 1
                # setattr(mod, "scale", scale[0])
                patched_mod = FP8QDQLinear(mod.in_features, mod.out_features)
                patched_mod.bias = mod.bias
                patched_mod.weight.data = q_param
                patched_mod.scale = torch.tensor(scale[0])
                patched_mod.weight_scale = torch.tensor(weight_scale.item())
            else:
                patched_mod = FP8QDQEmbeddingBag(
                    weight_shape=mod.weight.shape,
                    max_norm=mod.max_norm,
                    norm_type=mod.norm_type,
                    scale_grad_by_freq=mod.scale_grad_by_freq,
                    mode=mod.mode,
                    sparse=mod.sparse,
                    include_last_offset=mod.include_last_offset,
                    padding_idx=mod.padding_idx)
                patched_mod.weight_scale = weight_scale.item()
                patched_mod.weight.data = q_param
            parent = parent_child_mod_dict[mod].parent
            name = parent_child_mod_dict[mod].name
            setattr(parent, name, patched_mod)

def pt2e(model, inputs):
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch.export import export_for_training

    with torch.no_grad():
        out = model(*inputs)
        exported_model = export_for_training(
            model,
            example_inputs,
            strict=True
        ).module()
        quantizer = X86InductorQuantizer()
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        prepared_model = prepare_pt2e(exported_model, quantizer)
        prepared_model(*inputs)
        converted_model = convert_pt2e(prepared_model)
        torch.ao.quantization.move_exported_model_to_eval(converted_model)
        converted_model(*inputs)
        return converted_model

import time

from torch._inductor import config as inductor_config
from torch._dynamo import config

config.error_on_recompile = True
# inductor_config.cpp_wrapper = True
inductor_config.max_autotune = False
inductor_config.freezing = True
inductor_config.aot_inductor.debug_compile = False

model = DenseArch(13, [512, 256, 128], "cpu")
example_inputs = (torch.randn(128, 13),)
print(model)
tmp0 = model.model._mlp[0]._linear(*example_inputs)
# tmp1 = model.model._mlp[0]._linear(*example_inputs)

import contextlib

ctx1 = contextlib.suppress()
ctx2 = torch.autocast("cpu", enabled=True, dtype=torch.bfloat16)
# dtype = torch.float
dtype = torch.float32
if dtype == torch.float32:
    ctx = ctx1
else:
    ctx = ctx2

with torch.no_grad(), ctx:
    qtype = torch.float8_e4m3fn
    refe = model(*example_inputs)
    if qtype == torch.int8:
        model = pt2e(model, example_inputs)
    else:
        inc_convert(model, dtype)
    test_eager = model(*example_inputs)
    model = torch.compile(model)
    model(*example_inputs)
    test = model(*example_inputs)
Yeah, we use `_register_custom_op` (https://github.com/pytorch/ao/blob/96aec6a3e713687c1728a20a08d5c54db0344377/torchao/utils.py#L180) to prevent the op from being decomposed during export, but it continues to be decomposed in Inductor.
Do you want the op to be preserved in Inductor?
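For concreteness, a minimal sketch of the mechanism being described: a custom op stays a single node in the `torch.export` graph as long as nothing decomposes it. The `demo::dequant_fp8` op below is a hypothetical stand-in, not torchao's actual `_register_custom_op` machinery or its real affine quant ops.

```python
import torch
from torch.library import custom_op

# Hypothetical stand-in op; torchao's real ops are quantize_affine / dequantize_affine.
@custom_op("demo::dequant_fp8", mutates_args=())
def dequant_fp8(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return t.to(torch.float32) * scale

@dequant_fp8.register_fake
def _(t, scale):
    return torch.empty_like(t, dtype=torch.float32)

class M(torch.nn.Module):
    def forward(self, t, scale):
        return dequant_fp8(t, scale)

ep = torch.export.export(
    M(), (torch.randn(4).to(torch.float8_e4m3fn), torch.tensor(2.0))
)
# The op shows up as a single torch.ops.demo.dequant_fp8 node here; whether it
# survives into Inductor depends on whether a decomposition is registered for it.
print(ep.graph_module.code)
```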
Yes. There is an issue where the fp8 weight gets folded into an fp32 weight by constant_fold. Or do we have any other way to avoid this issue?
Also, the quant/dequant decomposition makes the pattern complicated. Can we avoid decomposing here?
Hi @jerryzh168, I'm not sure whether removing the decomposition here would cause any other issues. Can we consider landing in PyTorch first and then migrating over? My PRs on PyTorch: https://github.com/pytorch/pytorch/pull/153602 https://github.com/pytorch/pytorch/pull/153601
Hi @jerryzh168, do you have any suggestions?
Hi @jerryzh168. Please let me explain the whole story.
What we want to do now is to enable FP8 quantization in PyTorch. Similar to INT8 quantization, we need to insert quantize and dequantize ops into the graph.
However, we met problems with these q/dq ops in both PyTorch core and Torchao.
PyTorch core:
- The `quantize_per_tensor` op does not support FP8. We want to fix it via https://github.com/pytorch/pytorch/pull/153601. And as you commented, the op is deprecated.

Torchao:
- In the fusion pass in Inductor, we want to match the pattern `fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op` and fuse it as `fp8_weight -> weight_pack -> fp8_op`. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1 After `constant_fold(gm)`, the pattern will be folded as `fp32_weight -> fp32_op`. Then the original pattern cannot be found any more, and the FP8 semantics are lost since the pattern is entirely in fp32 now.
- For INT8, the `int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op` pattern won't be folded because we mark `quantized_decomposed.dequantize_per_channel` impure so that it won't be folded: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 But for `torchao.dequantize_affine_float8`, we cannot do this because
  - it is an op from Torchao, which is unknown to the constant folder, and
  - it is decomposed into smaller ops, so we cannot put it in the list as a single op.

So, we think an easy and short-term solution is to modify the ops in PyTorch core via https://github.com/pytorch/pytorch/pull/153601. However, if we want to resolve the issue with Torchao, we need to
- add a method to the constant folder in Inductor to allow registration of impure ops (see the sketch after this list), and
- avoid decomposition of `torchao.dequantize_affine_float8` and register this op as impure so that it won't be constant-folded.
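For concreteness, a simplified sketch of the kind of check being referred to: the constant folder skips nodes whose target is one of the known dequantize ops, and that set is effectively hard-coded, which is why an out-of-tree op cannot simply be added today. This is an illustration under those assumptions, not the actual `ConstantFolder` implementation.

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

# The constant folder skips nodes whose target is one of the known dequantize ops,
# so the int8 weight chain survives until the quantization fusion pass runs. The
# set is effectively hard-coded inside PyTorch, which is why an out-of-tree op
# such as torchao.dequantize_affine_float8 cannot simply be appended to it today.
_PROTECTED_DEQUANT_OPS = {
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
}

def should_skip_folding(node: torch.fx.Node) -> bool:
    """Simplified stand-in for the impurity check applied during constant folding."""
    return node.op == "call_function" and node.target in _PROTECTED_DEQUANT_OPS
```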
Do you think the short term solution makes sense? And for the solution with Torchao, do you have more comments or concerns? We are looking forward to your suggestions. Thanks.
@Xia-Weiwen thanks for the clear summary.
I have duplicated the constant_fold code in torchao: https://github.com/pytorch/ao/blob/60d63a637f5091d7c6917b3c28bca98540136600/torchao/quantization/pt2e/quantize_pt2e.py#L12, would it be enough for you to add torchao.dequantize_affine_float8 there?
I agree that for the longer term, inductor should allow registration for impure ops, cc @eellison @jansel
for Avoid decomposition of torchao.dequantize_affine_float8 I think this is not done before, in INT8 path we explicitly decompose it for inductor right? what changed for float8?
Is dequantize impure? What is it mutating?
IMO this op should be decomposed in inductor. You can register the decomp in the same place the op is defined.
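A minimal sketch of what "register the decomp in the same place the op is defined" can look like, using a hypothetical `demo::dequant_fp8_v2` op rather than torchao's real `dequantize_affine_float8`:

```python
import torch
from torch._decomp import register_decomposition

# Hypothetical stand-in op (not torchao's real dequantize_affine_float8).
@torch.library.custom_op("demo::dequant_fp8_v2", mutates_args=())
def dequant_fp8_v2(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return t.to(torch.float32) * scale

# Once a decomposition like this exists, downstream compilers can lower the op to
# convert_element_type + mul, which is also what exposes the weight chain to
# constant folding, as discussed above.
@register_decomposition(torch.ops.demo.dequant_fp8_v2.default)
def _dequant_fp8_v2_decomp(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return t.to(torch.float32) * scale
```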
@jansel technically it's not, but we may need to preserve the dequantize op so it can be fused with other ops to become a quantized op that takes an integer tensor as input. Is there a different way to specify this?
Impure isn't what you are looking for. Impure means the op mutates one of its inputs, so when we functionalize we need to introduce more copies (which might increase memory usage if Inductor can't optimize the copies away).
Ops will be preserved if you don't write a decomp for them, which forces them to be ExternKernels and prevents fusion with other ops.
> Ops will be preserved if you don't write a decomp for them, which forces them to be ExternKernels and prevents fusion with other ops.

What about constant folding? What prevents an op from being constant folded (other than marking it as impure)? I think that's the original reason we marked these ops as impure.
I don't believe we have a dont-constant-fold flag (correct me if I'm wrong @eellison ), though maybe we should.
Thanks for your replies.
> I have duplicated the constant_fold code in torchao:

@jerryzh168 If I understand correctly, the duplicated code is used in convert_pt2e in Torchao. However, what we talked about was the constant-folding pass in Inductor here: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69. So, I don't think we can resolve the issue by adding something to the Torchao code.
We don't want the dequantize op decomposed because once it's decomposed, the op is gone and it becomes difficult to tell the constant folder not to fold such patterns. What do you think? Thanks.
@jansel There are patterns like `constant_quantized_weight -> dequantize -> fp32_op -> ...` in the quantization scenario, and during lowering with torch.compile we want to fuse such patterns into `constant_quantized_weight -> quantized_op -> ...`. The fusion is done via https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L72C1-L75C1. However, the constant folding pass is applied before the fusion pass: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69. So, we need some mechanism to keep the quantization pattern from being folded into `constant_fp32_weight -> fp32_op`; otherwise the fusion for quantization won't be applied and the quantization semantics are lost. Do you have any suggestions? Thanks.
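A distilled version of the pattern in question, much smaller than the reproducer above (illustrative only; it uses a plain `to()` + `mul` as the dequantize rather than a torchao op):

```python
import torch
import torch.nn.functional as F
from torch._inductor import config as inductor_config

inductor_config.freezing = True  # constant folding of frozen weights happens under freezing

class DequantLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # frozen int8 "weight" with a per-tensor scale
        self.register_buffer("w_int8", torch.randint(-128, 127, (8, 8), dtype=torch.int8))
        self.register_buffer("scale", torch.tensor(0.05))

    def forward(self, x):
        # constant_quantized_weight -> dequantize -> fp32_op
        w_fp32 = self.w_int8.to(torch.float32) * self.scale
        return F.linear(x, w_fp32)

# With freezing enabled, the dequantize chain feeding the linear is a pure
# function of constants, so the constant folder can precompute w_fp32 and the
# graph degenerates to constant_fp32_weight -> fp32_op before the quantization
# fusion pass ever sees it, unless the dequantize node is protected.
m = torch.compile(DequantLinear().eval())
with torch.no_grad():
    m(torch.randn(2, 8))
```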
@Xia-Weiwen
> We don't want the dequantize op decomposed because once it's decomposed, the op is gone and it becomes difficult to tell the constant folder not to fold such patterns. What do you think? Thanks.

This makes sense. How did it work before? Also, as Jason mentioned, if you don't register a decomposition for it, it won't be decomposed. Maybe we could try adding an option to skip the registration here: https://github.com/pytorch/ao/blob/4d5f65711f7c53985d09e3a8c6aa8d8549f7d5a4/torchao/utils.py#L228 Do you want to test this out with the new affine quant ops?
I will do it. The plan (with a rough sketch after this list) is to:
- create `_dont_constant_fold` on the PyTorch side
- add an option to skip the decomposition registration on the torchao side
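A rough sketch of how those two pieces could fit together; every name here (`_dont_constant_fold`, `register_dont_constant_fold`, the commented-out registration call) is a hypothetical placeholder for whatever the PRs end up adding, not an existing PyTorch or torchao API:

```python
import torch

# --- PyTorch side (hypothetical): a registry the Inductor constant folder consults ---
_dont_constant_fold: set = set()

def register_dont_constant_fold(op) -> None:
    """Ops registered here would be treated like the INT8 dequantize ops today:
    never folded away, so the quantized-weight pattern survives until fusion."""
    _dont_constant_fold.add(op)

# --- torchao side (hypothetical): keep the fp8 dequant op as a single node ---
# If torchao exposes an option to skip registering the Inductor decomposition, the
# op stays a single node in the frozen graph and can then be registered as
# not-constant-foldable, e.g.:
#
#   register_dont_constant_fold(torch.ops.torchao.dequantize_affine_float8.default)
```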