transformer_nuggets
add an option to quantize in chunks
Repro:
```python
from torchao.dtypes import to_nf4
from transformer_nuggets.quant.nf4_tensor import NF4Tensor
import torch
from pathlib import Path
from transformer_nuggets.utils.benchmark import save_memory_snapshot
import logging; logging.basicConfig(level=logging.INFO)

torch.set_default_device("cuda:0")

mem_reserved = torch.cuda.max_memory_reserved()
print(f"{mem_reserved / 1e9} before anything")

z = torch.rand((128256, 4096), dtype=torch.bfloat16)
mem_reserved = torch.cuda.max_memory_reserved()
print(f"{mem_reserved / 1e9} after bf16 tensor allocation")

with save_memory_snapshot(Path("nf4_memory")):
    nf4_tensor_new = NF4Tensor.from_tensor(z)

mem_reserved = torch.cuda.max_memory_reserved()
print(f"{mem_reserved / 1e9} after nf4 tensor created")

total_bytes = z.element_size() * z.numel() + nf4_tensor_new.element_size() * nf4_tensor_new.numel()
nf4_tensor_old = to_nf4(z)
torch.testing.assert_close(nf4_tensor_new.quantized_data, nf4_tensor_old.quantized_data)
print(f"total GB consumed by bf16 and nf4 tensor: {total_bytes / 1e9}")
```
You sir are a genius!!
Neat! Should I treat this as a better version of processing on CPU, or are they orthogonal, since we can chunk on CPU as well?
https://github.com/drisspg/lit-gpt/pull/1/files#diff-7c1edee33e1038f4f9b3ddfbd0274869a610ecedd7e43789f1e1d03bdf21fc54R97-R109
@weifengpy Probably. I will put up a PR against torchAO; I only worked here since I wasn't able to pip install it.