ao icon indicating copy to clipboard operation
ao copied to clipboard

Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead

Open agunapal opened this issue 5 months ago • 1 comment

Getting this error with int4 quantization.

Maybe a noob question: is this a bug, or does int4 require the weights to be in bfloat16?

Traceback (most recent call last):
  File "/home/agunapal/torch_ao/vit_ao.py", line 16, in <module>
    quantize_(model, int4_weight_only())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 463, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 199, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 393, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 319, in pack_tinygemm_scales_and_zeros
    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 128, in guard_dtype_size
    raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
ValueError: Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead.

Code for repro:

import torch
import torchao

from torchvision.models import vit_b_16, ViT_B_16_Weights 
from torchao.utils import benchmark_model
from torchao.quantization import int4_weight_only, quantize_

# Repro script: apply int4 weight-only quantization to a torchvision ViT-B/16,
# compile it, and benchmark a forward pass.
torch.set_float32_matmul_precision('high')

dtype = torch.float32
device = "cuda"
N = 1  # batch size of the benchmark input

# Rows of (method, batch_size, elapsed_time), filled by the benchmark below
# and printed at the end. Must be defined before the first append.
result = []

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
# NOTE: this is the call that raises the reported
# "Expected Tensor argument scales to have dtype torch.bfloat16" ValueError.
# The model is quantized exactly as loaded; `dtype`/`device` are only applied
# afterwards, on the compiled model.
# NOTE(review): presumably the weights are float32 at this point (matching
# "got torch.float32" in the error) and int4 needs bfloat16 — TODO confirm.
quantize_(model, int4_weight_only())
model = torch.compile(model, mode='max-autotune').to(device).to(dtype)
method = "int4 quantize followed by compile"
# Named `example_inputs` rather than `input` to avoid shadowing the builtin.
example_inputs = (torch.randn(N, 3, 224, 224).to(device).to(dtype),)

with torch.no_grad():
    # warmup
    benchmark_model(model, 20, example_inputs)
    # benchmark
    result.append((method, N, benchmark_model(model, 100, example_inputs)))

# Distinct loop names so the module-level `method` and `N` are not rebound.
for label, batch_size, elapsed_time in result:
    print(f"batch_size={batch_size} : elapsed time {elapsed_time:.3f} ms :  {label} ")

agunapal avatar Sep 11 '24 23:09 agunapal