torch.compile() the quantization method
Decorate the quantize() method with torch.compile. This uses much less GPU VRAM and avoids OOM when quantizing Llama-3.1-405B (which previously OOMed on my machine even when quantizing one layer at a time).
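For reference, the change is equivalent to wrapping the quantizer with torch.compile from user code; a minimal sketch (not the PR diff itself, which decorates the method inside hqq.core.quantize):

import torch
from hqq.core.quantize import Quantizer

# Monkey-patched equivalent of the decorator: compile the quantization entry
# point; the PR reports this uses much less VRAM than the eager path.
Quantizer.quantize = torch.compile(Quantizer.quantize)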
@mobicham
Thank you @rationalism ! Added a few comments
@mobicham thank you! I don't see the comments though?
Oh, you don't see this in the review: https://github.com/mobiusml/hqq/pull/116/files/631ea011d8432b8a76518b0adc072574969d8771 ?
@mobicham I don't see any comments or review, no, sorry
I just tried this one and it compiles without graph breaks:
# Dependencies as in hqq.core.optimize (shrink_lp_op is defined in the same module):
import numpy as np
import torch
from torch import Tensor, float16, float32
from typing import Union


@torch.inference_mode()
def optimize_weights_proximal_legacy(
    tensor: Tensor,
    scale: Tensor,
    zero: Tensor,
    min_max: list,
    axis: int = 0,
    device: Union[str, None] = None,
    opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20},
    verbose: bool = False,
) -> tuple:
    if device is None:
        device = tensor.device
    else:
        device = torch.device(device)

    # Cast everything (including scale/zero) after the dtype is determined
    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)

    # Optimization hyper-parameters, kept as tensors on the target device
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters = opt_params["iters"]

    best_error = torch.tensor(1e4, dtype=torch.float32, device=device)
    for i in range(iters):
        # Quantize / dequantize round trip
        W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
        W_r = (W_q - zero) / scale
        # Lp-norm shrinkage on the residual, then update the zero-point
        W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
        zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
        beta *= kappa
        current_error = torch.abs(W_f - W_r).mean().float()
        if verbose:
            # .item() forces a sync; only hit when verbose=True
            print(i, np.round(current_error.item(), 6))
        if current_error < best_error:
            best_error = current_error
        else:
            break

    scale = scale.to(tensor.device)
    zero = zero.to(tensor.device)
    del W_f, W_q, W_r, W_e
    torch.cuda.empty_cache()

    W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
    return W_q, scale, zero
then
import torch
device = 'cuda:0'
backend = 'torchao_int4' #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend=="bitblas" else torch.bfloat16
cache_dir = '.'
model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
########################################################################
#Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *
#Load
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")
#Quantize
#torch._dynamo.config.capture_scalar_outputs = True
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize = torch.compile(Quantizer.quantize)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
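One optional way to double-check the "no graph breaks" claim (my addition, not part of the snippet above) is to compile with fullgraph=True, used in place of the two torch.compile lines above:

# Strict variant: any graph break becomes a hard error, so a successful
# quantize_model run confirms there were none.
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights, fullgraph=True)
Quantizer.quantize = torch.compile(Quantizer.quantize, fullgraph=True)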
I will take a look at Quantizer.quantize compilation now.
Works with Quantizer.quantize compiled as well! I suggest we do the following:
- We remove the @torch.compile decorator and do the compilation outside, either like this:
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
or like this
Quantizer.quantize = torch.compile(Quantizer.quantize)
This way, if the compilation breaks, the user can simply skip this step (see the sketch below).
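A minimal sketch of that opt-in flow (my illustration, not code from the PR; the use_compile flag is hypothetical):

import torch
from hqq.core.quantize import Quantizer

use_compile = True  # hypothetical flag: set to False to skip compilation entirely
if use_compile:
    # Keep references to the eager functions so they can be restored if the
    # compiled versions fail at their first call (torch.compile errors surface lazily).
    _eager_optimize_weights = Quantizer.optimize_weights
    _eager_quantize = Quantizer.quantize
    Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
    Quantizer.quantize = torch.compile(Quantizer.quantize)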
Then in optimize_weights..., we need to move the parameter casts after determining the dtype, so that they have exactly the same dtype as the other tensors:
....
dtype = float16 if (device.type == "cuda") else float32
W_f = tensor.to(dtype=dtype, device=device)
scale = scale.to(dtype=dtype, device=device)
zero = zero.to(dtype=dtype, device=device)
# Params
lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
beta = torch.tensor(opt_params["beta"],dtype=dtype, device=device)
kappa = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
iters = opt_params["iters"]
....
Ideally we do this for the optimize_weights_proximal_v2 version as well.
:pray:
Closing this PR since there has been no update for over a year, but please feel free to open a new one!