Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead
Getting this error with int4 quantization. Maybe a noob question: is this a bug, or does int4 require the weights to be in bfloat16?
Traceback (most recent call last):
  File "/home/agunapal/torch_ao/vit_ao.py", line 16, in <module>
    quantize_(model, int4_weight_only())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 463, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 199, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 393, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 319, in pack_tinygemm_scales_and_zeros
    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 128, in guard_dtype_size
    raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
ValueError: Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead.
Code for repro:
import torch
import torchao
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.utils import benchmark_model
from torchao.quantization import int4_weight_only, quantize_
torch.set_float32_matmul_precision('high')
dtype = torch.float32
device = "cuda"
N = 1
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
quantize_(model, int4_weight_only())
model = torch.compile(model, mode='max-autotune').to(device).to(dtype)
method = "int4 quantize followed by compile"
input = (torch.randn(N, 3, 224, 224).to(device).to(dtype),)

result = []
with torch.no_grad():
    # warmup
    benchmark_model(model, 20, input)
    # benchmark
    result.append((method, N, benchmark_model(model, 100, input)))

for (method, N, elapsed_time) in result:
    print(f"batch_size={N} : elapsed time {elapsed_time:.3f} ms : {method}")
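In case it helps, here is a minimal workaround sketch that avoids the guard for me: cast the model to bfloat16 (and move it to CUDA) before calling quantize_. The bfloat16 requirement is inferred from the traceback (pack_tinygemm_scales_and_zeros guards scales to bfloat16), not from the torchao docs, so treat it as an assumption:

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.quantization import int4_weight_only, quantize_

device = "cuda"

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

# Assumption: the int4 tinygemm packing expects bfloat16 scales/zeros,
# so cast weights to bfloat16 before quantizing.
model = model.to(device=device, dtype=torch.bfloat16)
quantize_(model, int4_weight_only())

# Inputs must match the model dtype.
x = torch.randn(1, 3, 224, 224, device=device, dtype=torch.bfloat16)
with torch.no_grad():
    out = model(x)

This runs without the ValueError on my end, but I'd still like to confirm whether float32 weights are supposed to be supported for int4_weight_only.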