Easy-Transformer
[Question] demo of 4bit quantized Llama -- what's next?
I made a demo of loading a 4-bit quantized Llama 2; it seems to work and uses ~6 GB of GPU RAM. Limitations:
- only supports Llama, and only 4-bit quantization
- only tested with Llama-7b; I haven't yet verified that all hooks work
- does not work properly with multiple GPUs
- the code needs some refactoring to make it cleaner
branch: https://github.com/coolvision/TransformerLens/tree/llama_4bit
demo: https://github.com/coolvision/TransformerLens/blob/llama_4bit/demos/Llama_quantized.ipynb
It requires Torch 2.1.1 and the bitsandbytes library. To install:
poetry source add --priority=supplemental torch https://download.pytorch.org/whl/cu118
poetry add torch==2.1.1+cu118 --source torch
pip install bitsandbytes
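For reference, a minimal sketch of the kind of loading this enables, assuming the branch accepts a bitsandbytes-quantized HF model via the existing `hf_model` argument of `HookedTransformer.from_pretrained`; the model name and config values here are illustrative, and the exact interface on the branch may differ:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from transformer_lens import HookedTransformer

# Assumed model id; any local Llama-2-7b checkpoint should work the same way.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"

# Quantize the HF model to 4-bit on load; matmuls still run in fp16 after dequant.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
hf_model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH, quantization_config=quant_config, device_map="auto"
)
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

# Wrap it in TransformerLens. Weight-processing steps that rewrite the weights
# (folding LayerNorm, centering) are disabled, since the weights stay quantized.
model = HookedTransformer.from_pretrained(
    "Llama-2-7b",  # assumed TransformerLens model name
    hf_model=hf_model,
    tokenizer=tokenizer,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

print(model.generate("The capital of France is", max_new_tokens=10))
```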
My questions are:
- Would a PR with only Llama support make sense, with plans to add other models later? Or should all models be supported?
- What should come next: supporting more models, more quantization settings, or multi-GPU support?
Cool! If you can get caching and patching to work, this would be a very exciting addition. It'd be best to support as many models as possible, but even an MVP PR that just supports LLaMA would be great.
I'd prioritise more models > multi-GPU > quantization settings (what do you mean by quantization settings? Eg 8 bit?)
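For context, the caching and patching to verify is just the standard TransformerLens API. A small sketch, assuming `model` is the quantized HookedTransformer from the demo, with an arbitrary layer and position:

```python
from transformer_lens import utils

clean_tokens = model.to_tokens("The Eiffel Tower is in the city of")
corrupt_tokens = model.to_tokens("The Colosseum is in the city of")

# Caching: grab all intermediate activations on the clean prompt.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

# Patching: on the corrupted prompt, overwrite one residual-stream position
# with the cached clean activation and see how the logits change.
layer, pos = 10, 5  # arbitrary choices for illustration
hook_name = utils.get_act_name("resid_pre", layer)

def patch_resid_pre(resid_pre, hook):
    resid_pre[:, pos, :] = clean_cache[hook_name][:, pos, :]
    return resid_pre

patched_logits = model.run_with_hooks(
    corrupt_tokens,
    fwd_hooks=[(hook_name, patch_resid_pre)],
)
```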
> what do you mean by quantization settings? Eg 8 bit?
Yes, with 4-bit quantization, bitsandbytes just dequantizes the weights and multiplies with torch:
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
So all the inference is done in fp16 or fp32, and the only part affected by quantization is multiplying the large weight matrices. With 8-bit, bitsandbytes tries to use efficient mixed-precision inference, which makes it a bit harder to support.
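A standalone illustration of that 4-bit path (not code from the branch; it needs a CUDA device and uses arbitrary shapes):

```python
import torch
import bitsandbytes.functional as F

# Activations stay in fp16; only the weight matrix is stored in 4-bit.
A = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

W_4bit, quant_state = F.quantize_4bit(W)        # pack weights into 4-bit blocks
W_deq = F.dequantize_4bit(W_4bit, quant_state)  # back to fp16 just for the matmul

out = torch.nn.functional.linear(A, W_deq.to(A.dtype))
print(out.shape, (out - A @ W.t()).abs().max())  # small quantization error
```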
Added a PR: #486