
[wip] SpinQuant

Open tobiasvanderwerff opened this pull request 1 year ago • 8 comments

Corresponding issue: #579

This PR adds SpinQuant integration to pytorch/ao. See the paper for details: https://arxiv.org/abs/2405.16406.

Initial results on Llama-2-7b are shown below (measured by Wikitext word perplexity).

| Model | Quantization | Baseline | SpinQuant (R2) | SpinQuant (R4) | SpinQuant (R2+R4) |
|-------|--------------|----------|----------------|----------------|-------------------|
| Llama-2-7B | None | 12.23 | 12.23 | 12.24 | 12.24 |
| | int8dq | 12.35 | 12.35 | 12.35 | 12.35 |
| | int8wo | 12.24 | 12.26 | 12.26 | 12.27 |
| | int4wo-64 | 12.87 | 12.85 | 12.82 | 12.80 |
| | int4wo-128 | 13.21 | 13.27 | 13.20 | 13.20 |

TODO

  • [x] implement R2
  • [x] implement R4
  • [x] implement layernorm weight fusion into linear layers (footnote 3 in the paper; see the sketch after this list)
  • [x] implement R1
  • [ ] ~implement R3~
  • [ ] ~Cayley optimization for R1 and R2 (not sure how feasible this is for inference -- they report ~1 hr on 8x A100 GPUs to run Cayley optimization for R1 and R2 using 800 samples of the WikiText2 calibration dataset)~
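
For reference, a rough sketch of what the fusion step amounts to (the helper name is made up; assumes an RMSNorm-style norm module with a `weight` and no bias, as in the Llama blocks):

```python
import torch
from torch import nn

@torch.no_grad()
def fuse_norm_into_linear(norm: nn.Module, linear: nn.Linear) -> None:
    # linear(norm(x)) = W @ (gamma * u) = (W @ diag(gamma)) @ u, where u is
    # the normalized-but-unscaled activation, so gamma can be folded into W.
    linear.weight.mul_(norm.weight)  # scales column j of W by gamma[j]
    norm.weight.fill_(1.0)           # the norm is now scale-free
```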

tobiasvanderwerff avatar Oct 01 '24 09:10 tobiasvanderwerff

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/983

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit fb3882f110895fd528875affd37abf21a29c961d with merge base 107e378f5ecc773d4f0c63c78218ad663061951c: :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 01 '24 09:10 pytorch-bot[bot]

Hey, this is looking nice so far. Long term we probably want to make these tensor subclasses so that we can make serialization easier. That way, rather than having to do load model -> convert model -> load checkpoint, you can just do load model -> load checkpoint.

Not absolutely critical, but long term it looks like there may be multiple use cases/APIs for SpinQuant, one explicitly for the Cayley QAT and one not, and unifying them via serialization will make composability much nicer.
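
A rough sketch of the subclass idea (illustrative only; torchao's real subclasses implement `__torch_dispatch__` and proper serialization hooks). The point is that the rotation rides along on the weight tensor itself, so a saved state_dict can be loaded without a separate convert step:

```python
import torch

class SpinQuantWeight(torch.Tensor):
    """Toy subclass: stores the rotated weight and remembers the rotation."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor, rotation: torch.Tensor):
        self = (weight @ rotation).as_subclass(cls)
        self.rotation = rotation  # kept as an attribute on the tensor
        return self

w = SpinQuantWeight(torch.randn(8, 8), torch.eye(8))
print(type(w).__name__, w.shape)  # SpinQuantWeight torch.Size([8, 8])
```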

HDCharles avatar Oct 03 '24 01:10 HDCharles

Good to know @HDCharles, I'll keep the tensor subclasses in mind. I was wondering: will the decision to integrate this into torchao depend on the performance delta it produces? Currently there is some Wikitext perplexity improvement, but it's perhaps not that significant.

tobiasvanderwerff avatar Oct 03 '24 06:10 tobiasvanderwerff

Update: I'm currently somewhat stuck on this PR. The R2 and R4 matrices are both implemented and show small perplexity improvements for int4wo-64 quantization (not by much though; see the table above). I've tried to stay as close as possible to the reference SpinQuant implementation, but these are the best results I can achieve so far, and they're not quite as good as the results in the paper. What still remains is the R3 rotation and R1 with Cayley optimization.

The R3 rotation is a bit tricky to implement because it requires a modification of Attention.forward() in the middle of the function, after the apply_rotary_emb calls:

https://github.com/pytorch/ao/blob/09b8b3c3156b99fa8356e2df30c1a65cbbf5cc91/torchao/_models/llama/model.py#L290-L302

In the SpinQuant repo they use a monkeypatch solution, but the code becomes a bit ugly in that case. At the same time, the paper shows that R3 has a minimal effect on performance (Table 3), so I'm not sure how much it's worth implementing.
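
For context, here's a minimal illustration of what R3 does (a naive O(d²) matmul stands in for the fast Hadamard kernel; assumes head_dim is a power of two). Because H is orthogonal, rotating q and k after RoPE leaves the attention scores unchanged in full precision; it only redistributes outliers so the KV cache quantizes better:

```python
import torch

def hadamard_matrix(d: int) -> torch.Tensor:
    # Sylvester construction with orthonormal scaling; d must be a power of two.
    h = torch.ones(1, 1)
    while h.shape[0] < d:
        h = torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0)
    return h / d**0.5

q = torch.randn(1, 8, 16, 64)  # (batch, n_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)

# R3 would be applied right after the apply_rotary_emb calls in Attention.forward():
H = hadamard_matrix(q.shape[-1])
q_rot, k_rot = q @ H, k @ H

# Attention scores are unchanged because H @ H.T = I.
scores = q_rot @ k_rot.transpose(-1, -2)
assert torch.allclose(scores, q @ k.transpose(-1, -2), atol=1e-3)
```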

Lastly, I have not added the R1 matrices, which would require adding a Cayley optimization procedure. Currently the SpinQuant changes are immediately applicable at inference time, but Cayley optimization would take a while to run (they report ~1 hr on 8x A100 GPUs for R1 and R2, using 800 samples of the WikiText2 calibration dataset). Alternatively, it might be possible to train these matrices once for a model like Llama-7B and ship them as add-on weights.
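
A toy sketch of the Cayley idea: parameterizing R through the Cayley transform of a skew-symmetric matrix keeps R exactly orthogonal under plain gradient descent, with no projection step. The objective below is a stand-in; the paper optimizes the rotations against the quantized network's actual loss (and uses Cayley SGD rather than this simple parameterization):

```python
import torch

d = 16
param = torch.zeros(d, d, requires_grad=True)  # unconstrained parameter
opt = torch.optim.SGD([param], lr=0.05)
W = torch.randn(d, d)  # stand-in weight matrix

for _ in range(100):
    A = param - param.T                   # skew-symmetric by construction
    I = torch.eye(d)
    R = torch.linalg.solve(I + A, I - A)  # Cayley transform: R is orthogonal
    loss = (W @ R).abs().max()            # crude outlier proxy, not the real loss
    opt.zero_grad()
    loss.backward()
    opt.step()

print(torch.dist(R.detach().T @ R.detach(), torch.eye(d)))  # ~0
```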

I would very much appreciate some feedback on how to proceed with this.

tobiasvanderwerff avatar Oct 03 '24 13:10 tobiasvanderwerff

I have unblocked myself somewhat regarding the R1 rotation matrices: the authors provide downloads of the optimized R1/R2 weights. I could try these out to see what kind of performance difference to expect before implementing Cayley optimization here. My only concern is that their Llama implementation might not be 100% identical to the one in torchao, in which case the R1 weights might not transfer as well, but it seems worth trying anyway.

tobiasvanderwerff avatar Oct 03 '24 14:10 tobiasvanderwerff

I think we can merge it and continue working on it regardless, though accuracy improvements are definitely a good metric for how useful it is. Even in their paper, the improvement from SpinQuant for 4-16-16 is pretty small, even with Cayley optimization. It's mostly 4-4-16 where it starts to outperform other methods by a significant margin. We're working on getting some kernels for that in the next 1-2 weeks, so it may be more useful for that use case. For now I'd run the accuracy benchmarks with groupsize=32 rather than 64/128, since that's the minimum group size.

Yeah, the monkeypatch is pretty messy; it feels like we can do this in a better way with either tensor subclasses or something else.

HDCharles avatar Oct 03 '24 16:10 HDCharles

I'll do a final reformat and add some more results in the next few days @HDCharles

tobiasvanderwerff avatar Oct 05 '24 06:10 tobiasvanderwerff

Hi @tobiasvanderwerff, do you mind reformatting hadamard_utils.py so we don't end up with a 10k-line file? You could even split the matrices out into a separate file like _hadamard_matrices.py, so the rest of hadamard_utils.py is easier to review.

andrewor14 avatar Oct 07 '24 17:10 andrewor14

> Hi @tobiasvanderwerff, do you mind reformatting hadamard_utils.py so we don't end up with a 10k-line file? You could even split the matrices out into a separate file like _hadamard_matrices.py, so the rest of hadamard_utils.py is easier to review.

Hi @tobiasvanderwerff @andrewor14, could you use this implementation: https://fburl.com/code/d3nuagm4? It's faster and much easier to read.

wat3rBro avatar Oct 10 '24 21:10 wat3rBro

It's faster? Do you have a link to benchmarks?

HDCharles avatar Oct 11 '24 19:10 HDCharles

> It's faster? Do you have a link to benchmarks?

The benchmark is in the summary of D61891002.

wat3rBro avatar Oct 11 '24 19:10 wat3rBro

Hi @tobiasvanderwerff, great work! I was wondering if you have tested the end-to-end generation performance (tokens/s)?

yiliu30 avatar Oct 11 '24 23:10 yiliu30

I have not tested tokens/s generation @yiliu30, but I can test this if you want.

tobiasvanderwerff avatar Oct 14 '24 05:10 tobiasvanderwerff

> I have not tested tokens/s generation @yiliu30, but I can test this if you want.

Thank you, @tobiasvanderwerff! I'm primarily interested in studying the computational overhead introduced by R4, and I was wondering if the hadamard_transform might break torch.compile.

yiliu30 avatar Oct 14 '24 14:10 yiliu30

Thanks for bringing this up @yiliu30 -- I tested this and it looks like the custom Hadamard transform kernel indeed breaks torch.compile. I'll investigate this and get back to you.
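
For anyone hitting the same thing: one common pattern on recent PyTorch (2.4+) is to register the kernel as a custom op with a fake implementation so dynamo can trace through it without a graph break. This is only a sketch, not necessarily the fix used in this PR: the op name is made up, and a pure-PyTorch fast Walsh-Hadamard transform stands in for the CUDA kernel (power-of-two last dim assumed):

```python
import torch

@torch.library.custom_op("spinquant::hadamard_transform", mutates_args=())
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    # Iterative fast Walsh-Hadamard transform along the last dim,
    # standing in for the CUDA kernel.
    d, out, h = x.shape[-1], x.contiguous(), 1
    while h < d:
        out = out.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = out[..., 0, :], out[..., 1, :]
        out = torch.stack([a + b, a - b], dim=-2).reshape(*x.shape)
        h *= 2
    return out / d**0.5  # orthonormal scaling

@hadamard_transform.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation that lets dynamo trace the op.
    return torch.empty_like(x)

fn = torch.compile(lambda x: hadamard_transform(x), fullgraph=True)
print(fn(torch.randn(4, 64)).shape)  # torch.Size([4, 64])
```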

tobiasvanderwerff avatar Oct 14 '24 17:10 tobiasvanderwerff

@yiliu30 FYI I fixed the issue with torch.compile -- you can see the benchmark results here.

tobiasvanderwerff avatar Oct 15 '24 08:10 tobiasvanderwerff