[wip] SpinQuant
Corresponding issue: #579
This PR adds SpinQuant integration to pytorch/ao. See the paper for details: https://arxiv.org/abs/2405.16406.
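A minimal usage sketch of the intended flow (fold in the rotations while the model is still in floating point, then quantize). `quantize_` and `int4_weight_only` are existing torchao APIs; the `apply_spinquant` name, its `torchao.prototype.spinquant` location, and whether it accepts a Hugging Face checkpoint directly are assumptions rather than the final API of this PR.

```python
# Sketch only: the apply_spinquant entry point and its import path are assumptions;
# quantize_/int4_weight_only are existing torchao quantization APIs.
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int4_weight_only

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="cuda"
)

# 1) Fold the layernorm weights and the R1/R2 rotations into the fp weights
#    (R4 is applied online at inference time).
from torchao.prototype.spinquant import apply_spinquant  # assumed entry point
apply_spinquant(model)

# 2) Quantize as usual, e.g. the int4wo-64 configuration from the table below.
quantize_(model, int4_weight_only(group_size=64))
```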
Initial results on Llama-2-7B are shown below, measured by Wikitext word perplexity (a sketch of one way to run this eval follows the table).
Model | Quantization | Baseline | SpinQuant (R2) | SpinQuant (R4) | SpinQuant (R2+R4) |
---|---|---|---|---|---|
Llama-2-7B | None | 12.23 | 12.23 | 12.24 | 12.24 |
Llama-2-7B | int8dq | 12.35 | 12.35 | 12.35 | 12.35 |
Llama-2-7B | int8wo | 12.24 | 12.26 | 12.26 | 12.27 |
Llama-2-7B | int4wo-64 | 12.87 | 12.85 | 12.82 | 12.80 |
Llama-2-7B | int4wo-128 | 13.21 | 13.27 | 13.20 | 13.20 |
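The PR text doesn't state which harness produced these numbers. Below is a hedged sketch of one way to get a Wikitext word-perplexity figure using lm-evaluation-harness; the model ID, harness version, and settings are assumptions, so results won't necessarily match the table.

```python
# Sketch of a Wikitext word-perplexity measurement with lm-evaluation-harness
# (pip install lm_eval); not necessarily the harness/settings used for the table above.
import lm_eval
from lm_eval.models.huggingface import HFLM

# HFLM also accepts an already-rotated/quantized nn.Module via pretrained=<model>.
lm = HFLM(pretrained="meta-llama/Llama-2-7b-hf", dtype="bfloat16", batch_size=1)
results = lm_eval.simple_evaluate(model=lm, tasks=["wikitext"])

# The wikitext task reports word perplexity alongside byte perplexity / bits-per-byte.
print(results["results"]["wikitext"])
```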
TODO
- [x] implement R2
- [x] implement R4
- [x] implement layernorm weight fusion into linear layers (footnote 3 in the paper; see the sketch after this list)
- [x] implement R1
- [ ] ~implement R3~
- [ ] ~Cayley optimization for R1 and R2 (not sure how feasible this is for an inference-only flow -- the authors report it takes about 1 hour on 8x A100 GPUs to optimize R1 and R2 with 800 WikiText2 calibration samples)~
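For reference, here is an illustrative plain-PyTorch sketch of two of the items above: folding an RMSNorm weight into the linear layers it feeds (footnote 3 of the paper), and building a random Hadamard-style rotation like the one used for R4. The shapes and helper names are assumptions for illustration; this is not the implementation in this PR.

```python
# Illustrative sketch, not the code in this PR: (1) fold an RMSNorm weight into the
# linear layers it feeds, and (2) build a random Hadamard-style orthogonal rotation.
import torch


def fuse_norm_into_linears(norm: torch.nn.Module, linears: list[torch.nn.Linear]) -> None:
    """Fold norm.weight into each following linear, then reset norm.weight to ones.

    y = linear(norm(x)) is unchanged afterwards, but the norm becomes weight-free,
    so an orthogonal rotation of its output commutes with it.
    """
    w_norm = norm.weight.data  # shape: [hidden_dim]
    for lin in linears:
        # linear computes x @ W.T, so scaling its input elementwise by w_norm is the
        # same as scaling the columns of W: W[:, j] *= w_norm[j].
        lin.weight.data = lin.weight.data * w_norm.unsqueeze(0)
    norm.weight.data = torch.ones_like(w_norm)


def random_hadamard_matrix(n: int, seed: int = 0) -> torch.Tensor:
    """Orthonormal Hadamard matrix with random column sign flips (n must be a power of 2)."""
    assert n & (n - 1) == 0, "Sylvester construction assumes a power-of-two size"
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    H = H / (n ** 0.5)  # make rows/columns orthonormal
    g = torch.Generator().manual_seed(seed)
    signs = (torch.randint(0, 2, (n,), generator=g) * 2 - 1).float()
    return H * signs  # flipping column signs preserves orthogonality


if __name__ == "__main__":
    # Sanity check: folding an orthogonal H into the weight (W -> W @ H) and rotating
    # the input (x -> x @ H) leaves the output of a linear layer unchanged.
    torch.manual_seed(0)
    lin = torch.nn.Linear(8, 4, bias=False)
    x = torch.randn(2, 8)
    H = random_hadamard_matrix(8)
    rotated = torch.nn.Linear(8, 4, bias=False)
    rotated.weight.data = lin.weight.data @ H
    assert torch.allclose(lin(x), rotated(x @ H), atol=1e-5)
```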