
torch.compile() the quantization method

rationalism opened this pull request 1 year ago • 6 comments

Decorate the quantize() method with torch.compile. This uses much less GPU VRAM, avoiding OOM when quantizing Llama-3.1-405B (which previously OOMed on my machine even when quantizing one layer at a time).
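For illustration, here is a minimal sketch of the idea on a toy routine (the function below is a stand-in, not hqq's actual quantize()):

import torch

@torch.compile  # compiled lazily, on the first call
def toy_quantize(W: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # 4-bit round-to-nearest affine quantization followed by dequantization;
    # running this compiled avoids materializing every eager intermediate.
    W_q = torch.round(W * scale + zero).clamp(0, 15)
    return (W_q - zero) / scale

device = "cuda" if torch.cuda.is_available() else "cpu"
W   = torch.randn(256, 256, device=device)
W_r = toy_quantize(W, torch.full((256, 1), 12.0, device=device), torch.zeros(256, 1, device=device))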

@mobicham

rationalism avatar Sep 10 '24 22:09 rationalism

Thank you @rationalism ! Added a few comments

mobicham avatar Sep 11 '24 10:09 mobicham

@mobicham thank you! I don't see the comments though?

rationalism avatar Sep 12 '24 02:09 rationalism

Oh, in the review, you don't see this https://github.com/mobiusml/hqq/pull/116/files/631ea011d8432b8a76518b0adc072574969d8771 ?

mobicham avatar Sep 12 '24 07:09 mobicham

@mobicham I don't see any comments or review, no, sorry

rationalism avatar Sep 12 '24 16:09 rationalism

I just tried this one and it compiles without graph breaks:

from typing import Union

import torch
from torch import Tensor, float16, float32

# Note: shrink_lp_op is hqq's lp-norm shrinkage operator, defined alongside
# this function in the library's optimizer module.

@torch.inference_mode()
def optimize_weights_proximal_legacy(
    tensor: Tensor,
    scale: Tensor,
    zero: Tensor,
    min_max: list,
    axis: int = 0,
    device: Union[str, None] = None,
    opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20},
    verbose: bool = False,
) -> tuple:

    if device is None:
        device = tensor.device
    else:
        device = torch.device(device)

    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)

    # Params
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta    = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa   = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters   = opt_params["iters"]

    best_error = torch.tensor(1e4, dtype=torch.float32, device=device)
    for i in range(iters):
        W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
        W_r = (W_q - zero) / scale
        W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
        zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
        beta *= kappa

        current_error = torch.abs(W_f - W_r).mean().float()
        if verbose:
            print(i, round(current_error.item(), 6))
        if current_error < best_error:
            best_error = current_error
        else:
            break

    scale = scale.to(tensor.device)
    zero = zero.to(tensor.device)
    del W_f, W_q, W_r, W_e
    torch.cuda.empty_cache()

    W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
    return W_q, scale, zero
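If you want to double-check the graph-break claim on your own setup, torch._dynamo.explain reports the captured graphs and any break reasons without failing the run (a sketch; the explain API has shifted a bit between PyTorch versions):

import torch

W     = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
scale = torch.rand(4096, 1, dtype=torch.float16, device='cuda') + 0.5
zero  = torch.zeros(4096, 1, dtype=torch.float16, device='cuda')

# Prints a summary of the traced graphs and the reason for each graph break, if any.
explanation = torch._dynamo.explain(optimize_weights_proximal_legacy)(
    W, scale, zero, min_max=[0, 15], axis=1
)
print(explanation)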

then

import torch
device        = 'cuda:0'
backend       = 'torchao_int4' #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend=="bitblas" else torch.bfloat16
cache_dir     = '.' 
model_id      = 'meta-llama/Meta-Llama-3-8B-Instruct'

########################################################################
#Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

#Load
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model     = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")

#Quantize
#torch._dynamo.config.capture_scalar_outputs = True
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize = torch.compile(Quantizer.quantize)

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)

I will take a look at Quantizer.quantize compilation now.

mobicham avatar Sep 12 '24 17:09 mobicham

Works with Quantizer.quantize compiled as well! I suggest the following: we remove the @torch.compile decorator and do the compilation outside, either like this

Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)

or like this

Quantizer.quantize = torch.compile(Quantizer.quantize)

This way if the compilation breaks the user can simply skip this step.
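For example, the opt-in compilation could look like this (just a sketch; disable_compiled_quantization is a hypothetical helper, not part of hqq):

import torch
from hqq.core.quantize import Quantizer

# Keep references to the eager implementations so they can be restored.
_eager_optimize_weights = Quantizer.optimize_weights
_eager_quantize         = Quantizer.quantize

# torch.compile wraps lazily; actual compilation (and any failure) happens on the first call.
Quantizer.optimize_weights = torch.compile(Quantizer.optimize_weights)
Quantizer.quantize         = torch.compile(Quantizer.quantize)

def disable_compiled_quantization():
    # Hypothetical helper: restore the eager versions if compilation misbehaves.
    Quantizer.optimize_weights = _eager_optimize_weights
    Quantizer.quantize         = _eager_quantize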

Then in optimize_weights..., we need to move the parameter setup to after determining the dtype, so that the parameters have the exact same dtype as the other tensors:

    ....
    dtype = float16 if (device.type == "cuda") else float32
    W_f = tensor.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)
    zero = zero.to(dtype=dtype, device=device)
    # Params
    lp_norm = torch.tensor(max(opt_params["lp_norm"], 0.1), dtype=dtype, device=device)
    beta    = torch.tensor(opt_params["beta"], dtype=dtype, device=device)
    kappa   = torch.tensor(opt_params["kappa"], dtype=dtype, device=device)
    iters   = opt_params["iters"]
    ....

Ideally we do this for the optimize_weights_proximal_v2 version as well.
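Something along these lines would cover both variants (a sketch; the module path for the optimizers is an assumption, adjust it to wherever the two functions actually live):

import torch
import hqq.core.optimize as hqq_optimize  # assumed module path

for name in ("optimize_weights_proximal_legacy", "optimize_weights_proximal_v2"):
    fn = getattr(hqq_optimize, name, None)
    if fn is not None:
        setattr(hqq_optimize, name, torch.compile(fn))

# Note: if Quantizer.optimize_weights already holds a direct reference to one of these
# functions, re-assign it as well (as above) so the compiled version is the one that runs.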

🙏

mobicham avatar Sep 12 '24 17:09 mobicham

Closing this PR since there has been no update for over a year, but please feel free to open a new one!

mobicham avatar Oct 24 '25 23:10 mobicham