[wip] SpinQuant
Corresponding issue: #579
This PR adds SpinQuant integration to pytorch/ao. See the paper for details: https://arxiv.org/abs/2405.16406.
Initial results on Llama-2-7B are shown below (measured by Wikitext word perplexity).
| Model | Quantization | Baseline | SpinQuant (R2) | SpinQuant (R4) | SpinQuant (R2+R4) |
|---|---|---|---|---|---|
| Llama-2-7B | None | 12.23 | 12.23 | 12.24 | 12.24 |
| | int8dq | 12.35 | 12.35 | 12.35 | 12.35 |
| | int8wo | 12.24 | 12.26 | 12.26 | 12.27 |
| | int4wo-64 | 12.87 | 12.85 | 12.82 | 12.80 |
| | int4wo-128 | 13.21 | 13.27 | 13.20 | 13.20 |
TODO
- [x] implement R2
- [x] implement R4
- [x] implement layernorm weight fusion into linear layers (footnote 3 in the paper; see the sketch after this list)
- [x] implement R1
- [ ] ~implement R3~
- [ ] ~Cayley optimization for R1 and R2 (not sure how feasible this is for inference -- it takes them 1hr to run Cayley optimization on 8x A100 GPUs for R1 and R2 using 800 samples of WikiText2 calibration dataset)~
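For the layernorm fusion item above, here is a minimal sketch of the idea as I understand footnote 3 (function and argument names are made up, assuming an RMSNorm-style norm whose learnable scale is applied elementwise right before one or more linear layers):

```python
import torch
import torch.nn as nn

# Sketch of the layernorm-weight fusion from footnote 3 of the paper (not the ao API):
# fold the norm's per-channel scale g into the columns of the following linear layers,
# then reset the norm scale to ones so the norm is scale-free.
@torch.no_grad()
def fuse_norm_scale_into_linear(norm: nn.Module, linears: list[nn.Linear]) -> None:
    for linear in linears:
        # linear(g * u) == (W * g) @ u, so scaling the columns of W absorbs g
        linear.weight.mul_(norm.weight.to(linear.weight.dtype))
    norm.weight.fill_(1.0)  # the norm itself becomes scale-free
```

As I understand it, this fusion is needed because an orthogonal rotation of the residual stream only commutes with a scale-free RMSNorm, which is what allows R1/R2 to be merged into the weights without changing the network output.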
Hey, this is looking nice so far. Long term we probably want to make these tensor subclasses so that we can make serialization easier; that way, rather than having to load model -> convert model -> load checkpoint, you can just do load model -> load checkpoint.
Not absolutely critical, but long term it looks like there may be multiple use cases/APIs for SpinQuant, one explicitly for the Cayley QAT and one not, and unifying them based on serialization will make composability much nicer.
Good to know @HDCharles, I'll keep the tensor subclasses in mind. I was wondering, will the choice to integrate this into torchao depend on the performance delta it produces? Currently, there is some Wikitext perf improvement but it's perhaps not that significant.
Update: I'm currently somewhat stuck on this PR. The R2 and R4 matrices are both implemented and show small perplexity improvements for int4wo-64 quantization (not much though, see the table above). I've tried to follow the reference SpinQuant implementation as closely as possible, but these are the best results I can achieve so far (and not quite as good as the results in the paper). What still remains is the R3 rotation and R1 with Cayley optimization.
The R3 rotation is a bit tricky to implement because it requires a modification of Attention.forward() in the middle of the function, after the apply_rotary_emb calls:
https://github.com/pytorch/ao/blob/09b8b3c3156b99fa8356e2df30c1a65cbbf5cc91/torchao/_models/llama/model.py#L290-L302
In the SpinQuant repo they use a monkeypatch solution, but the code becomes rather ugly that way. At the same time, the paper shows that R3 has a minimal effect on performance (Table 3), so I'm not sure how much it's worth implementing.
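For concreteness, here is a rough, self-contained sketch of what R3 would do (not the SpinQuant code, and not wired into Attention.forward): after RoPE, queries and keys are rotated along the head dimension by the same orthogonal matrix, which leaves attention scores unchanged but changes the tensors that reach the KV-cache quantizer. A random orthogonal matrix stands in for the Hadamard matrix used in the paper.

```python
import torch

# Rough sketch of the R3 rotation (illustration only, not the SpinQuant implementation)
def apply_r3(q: torch.Tensor, k: torch.Tensor, r3: torch.Tensor):
    # q, k: [batch, n_heads, seq, head_dim]; r3: [head_dim, head_dim], orthogonal
    return q @ r3, k @ r3

head_dim = 128
r3, _ = torch.linalg.qr(torch.randn(head_dim, head_dim, dtype=torch.float64))
q = torch.randn(1, 32, 16, head_dim, dtype=torch.float64)
k = torch.randn(1, 32, 16, head_dim, dtype=torch.float64)
q_rot, k_rot = apply_r3(q, k, r3)
# (q R3)(k R3)^T == q k^T because R3 is orthogonal, so attention output is preserved
assert torch.allclose(q_rot @ k_rot.transpose(-1, -2), q @ k.transpose(-1, -2))
```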
Lastly, I have not added the R1 matrices, which would require adding a Cayley optimization procedure. Currently, the SpinQuant changes can be applied directly at inference time, but running Cayley optimization would take some time to complete (they report ~1 hr on 8x A100 GPUs for R1 and R2, using 800 samples of the WikiText2 calibration dataset). It might also be possible to train these matrices once for a model like Llama-7B and ship them as add-on weights.
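As a rough illustration of what the optimization side could look like (this is an assumption about how it might be wired up, not the SpinQuant training code): the paper uses Cayley SGD on the Stiefel manifold, and PyTorch's built-in orthogonal parametrization with a Cayley map is one way to approximate keeping R1 orthogonal while a standard optimizer updates it.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

dim = 512  # 4096 for Llama-2-7B
r1_module = nn.Linear(dim, dim, bias=False)  # holder for the R1 rotation
orthogonal(r1_module, "weight", orthogonal_map="cayley")

# r1_module.weight is re-materialized as an orthogonal matrix on every access, so a
# plain optimizer over the underlying (unconstrained) parameters keeps R1 on the manifold.
optimizer = torch.optim.Adam(r1_module.parameters(), lr=1e-3)
w = r1_module.weight
assert torch.allclose(w @ w.T, torch.eye(dim), atol=1e-3)  # numerically orthogonal
```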
I would very much appreciate some feedback on how to proceed with this.
I have unblocked myself somewhat regarding the R1 rotation matrices: the authors provide downloads for the optimized R1/R2 weights. I could try these out to see what kind of performance difference can be expected before implementing Cayley optimization here. My only concern is that their Llama implementation might not be 100% identical to the one in torchao, which could mean the R1 weights don't work as well, but it seems worth trying anyway.
I think we can merge it and continue working on it regardless; accuracy improvements are definitely a good metric for how useful it is though. Even in the paper, for 4-16-16 the improvement from SpinQuant is pretty small even with Cayley optimization. It's mostly 4-4-16 where it starts to outperform other methods by a significant margin. We're working on getting some kernels for that in the next 1-2 weeks, so it may be more useful for that use case. For now I'd do accuracy benchmarks on groupsize=32 rather than 64/128, since that's the minimum group size.
Yeah the monkeypatch is pretty messy, feels like we can do this in a better way with either tensor subclasses or something else.
I'll do a final reformat and add some more results in the next few days @HDCharles
Hi @tobiasvanderwerff, do you mind reformatting hadamard_utils.py so we don't end up with a 10k line file? I feel you can even separate it into a separate file like _hadamard_matrices.py, so it's easier to review the other parts of hadamard_utils.py
Hi @tobiasvanderwerff @andrewor14 could you use this implementation https://fburl.com/code/d3nuagm4? It's faster and much easier to read.
It's faster? Do you have a link to benchmarks?
The benchmark is in the summary of D61891002.
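For readers of this thread who can't see the internal link, here is a minimal reference fast Walsh-Hadamard transform in PyTorch, purely for illustration (this is not the implementation from that diff): it runs in O(n log n) along the last dimension instead of materializing an n x n Hadamard matrix.

```python
import torch

# Minimal fast Walsh-Hadamard transform (illustration only). Assumes the last
# dimension is a power of two and uses the orthonormal scaling H / sqrt(n).
def fast_hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    assert n & (n - 1) == 0, "last dimension must be a power of two"
    y, h = x.clone(), 1
    while h < n:
        y = y.view(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2)  # butterfly step
        h *= 2
    return y.reshape(x.shape) / n**0.5

x = torch.randn(4, 128, dtype=torch.float64)
# H / sqrt(n) is orthogonal and symmetric, so applying the transform twice is the identity
assert torch.allclose(fast_hadamard_transform(fast_hadamard_transform(x)), x, atol=1e-10)
```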
Hi @tobiasvanderwerff, great work! I am wondering if you tested the end-to-end generation performance (tokens/s)?
I have not tested tokens/s generation @yiliu30, but I can test this if you want.
Thank you, @tobiasvanderwerff! I'm primarily interested in studying the computational overhead introduced by R4, and I was wondering if the hadamard_transform might break torch.compile.
Thanks for bringing this up @yiliu30 -- I tested this and it looks like the custom Hadamard transform kernel indeed breaks torch.compile. I'll investigate this and get back to you.
@yiliu30 FYI I fixed the issue with torch.compile -- you can see the benchmark results here.
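For anyone hitting the same problem: one common pattern for making a custom kernel visible to torch.compile without graph breaks is to register it as a custom op with a fake (meta) implementation. The sketch below uses made-up names, is not necessarily the fix applied in this PR, and needs a recent PyTorch (>= 2.4) for torch.library.custom_op.

```python
import torch

# Generic sketch: wrap the kernel as a custom op so torch.compile can trace through it.
@torch.library.custom_op("spinquant::hadamard_transform", mutates_args=())
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    # placeholder eager implementation; the real op would dispatch to the CUDA kernel
    return x.clone()

@hadamard_transform.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # shape/dtype propagation only, so the compiler never needs to run the kernel
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return hadamard_transform(x) * 2

f(torch.randn(8, 8))  # compiles without a graph break around the custom op
```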