Reducing memory usage: removing useless logits computation in generate()
What does this PR do?
This is the PR related to the discussion in #30860.
I followed what has been done in Jamba and added support for the num_logits_to_keep argument in forward(). However, even if this argument is None, the logits are only upcast to float if labels are passed (in order to accurately compute the loss). Otherwise, the upcasting only happens in the generate() functions.
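To illustrate the idea, here is a toy sketch of the pattern (not the actual Llama/Mistral modeling code; the GRU backbone, sizes, and names below are placeholders just to make the snippet self-contained and runnable):

```python
import torch
import torch.nn as nn

class TinyCausalLMSketch(nn.Module):
    """Toy model, only to illustrate where num_logits_to_keep comes into play."""

    def __init__(self, vocab_size=128, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.backbone = nn.GRU(hidden_size, hidden_size, batch_first=True)  # stand-in for the decoder stack
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, labels=None, num_logits_to_keep=None):
        hidden_states, _ = self.backbone(self.embed(input_ids))

        if num_logits_to_keep is None:
            # Old behavior: project every position through the LM head
            logits = self.lm_head(hidden_states)
        else:
            # Only project the last positions; generate() only needs the last one,
            # so the full (seq_len x vocab_size) logits matrix is never materialized
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Upcast to float only when a loss is computed, for numerical accuracy
            logits = logits.float()
            shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            shift_labels = labels[:, 1:].reshape(-1)
            loss = nn.functional.cross_entropy(shift_logits, shift_labels)
        return logits, loss

model = TinyCausalLMSketch()
ids = torch.randint(0, 128, (1, 10))
full_logits, _ = model(ids)                        # shape (1, 10, 128)
last_logits, _ = model(ids, num_logits_to_keep=1)  # shape (1, 1, 128)
```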
For now, I only modified Llama and Mistral, but if you agree on the changes I will add support for more models.
Benchmarks
Here I provide some benchmarks of the peak memory usage. For each input size, I generated 10 additional tokens. Of course, since for a few additional tokens the memory peak is dominated by the first forward pass (at least when computing the whole logits matrix), and since that first forward scales linearly with input size and batch size (with the new attention algorithms), the reduction factor is roughly constant across input sizes and generation methods (except for contrastive search, which artificially increases the batch size after the first forward, so its memory usage is slightly different). However, I still provide results for all generation methods here for completeness.
Basically we get:
- Llama3 8B -> MIND-BLOWING 3.62 memory usage reduction factor (due to the large vocabulary)
- Llama2 7B -> 1.17 reduction factor
- Mistral 7B -> 1.32 reduction factor
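To get an intuition for why the vocabulary size drives the reduction factor, here is a rough back-of-the-envelope calculation (assuming a float32 upcast of the logits, a 2048-token prompt picked only for illustration, and approximate vocabulary sizes; these numbers account for the logits tensor alone, while the real peak also includes weights, KV cache and activations, hence the smaller end-to-end factors above):

```python
def logits_size_gib(batch_size, seq_len, vocab_size, bytes_per_elem=4):
    # Size of the (batch, seq_len, vocab_size) float32 logits tensor in GiB
    return batch_size * seq_len * vocab_size * bytes_per_elem / 1024**3

# Full logits matrix for a 2048-token prompt vs. only the last position
print(logits_size_gib(1, 2048, 128_256))  # Llama3 (~128k vocab):       ~0.98 GiB
print(logits_size_gib(1, 1, 128_256))     # last position only:         ~0.0005 GiB
print(logits_size_gib(1, 2048, 32_000))   # Llama2/Mistral (32k vocab): ~0.24 GiB
```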
Note that the memory reduction shown here is on top of whatever gains #30536 already provides for a small number of new tokens, as I am comparing against the main transformers branch after that PR was merged. It integrates very nicely with that PR: that one provided most of its benefits when generating many tokens, while this one provides gains when only a few new tokens are generated.
Attached benchmark plots:
- greedy.pdf
- sample.pdf
- beam sample.pdf
- beam search.pdf
- group beam search.pdf
- contrastive search.pdf
Here is a link to the benchmark script: https://gist.github.com/Cyrilvallez/92f48e402aa2968c854a8128796f50c3
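For reference, the general measurement pattern is along these lines (a minimal sketch, not the gist itself; the checkpoint name, prompt, and generation settings are just examples, and a CUDA device is assumed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Meta-Llama-3-8B"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("Some long prompt " * 200, return_tensors="pt").to("cuda")

# Reset the CUDA memory counters, run a short generation, and read the peak
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=10, do_sample=False)

print(f"Peak allocated memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```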
Who can review?
@ArthurZucker @gante Let me know what you think about the proposed changes!
btw, a ratio of 3x lower peak memory consumption is 🔥 🔥 🔥
I just added the change to more models and rebased to avoid conflicts with new commits in main! For Cohere-based models, I notably measured a memory gain ratio of 6.68 due to the very large 256k vocabulary size 😄🔥
The last thing to take into account is your comment about the signature @ArthurZucker, but I'm not sure I understood correctly what you wanted to do 🤔
Make sure to rebase as the state of the main branch was changed quite a bit!
Will do! However, when playing with torch.compile, I noticed that adding a logger.warning_once() in the forward breaks the graph with the following error: Unsupported: call_method UserDefinedObjectVariable(Logger) warning_once [ConstantVariable()] {}. This is with the latest PyTorch version (2.3.1), so I will make sure to change that/make it compatible as well.
DO NOT MERGE YET: everything else is good, but I still need to sort out the logger.warning_once/compile issue
@ArthurZucker @gante everything is now ready.
From my tests, it seems like compile does not support any print-like functionality at the moment, either from print, logger or warnings.
I first wanted to add a logger.warning_once_compile_safe function, which I thought would simplify things and come in handy in the future as well, but I couldn't because it would require importing torch in the logging module, which breaks things.
So I just added a compile check everywhere.
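The inline guard looks roughly like this (a sketch only; the helper name and warning message are placeholders, and I use torch.compiler.is_compiling() here, which is available in recent PyTorch versions, while the exact check in the PR may differ):

```python
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

def maybe_warn(num_logits_to_keep):
    # Print-like calls break the dynamo graph, so only emit the warning
    # when running eagerly (i.e. not under torch.compile)
    if num_logits_to_keep is None and not torch.compiler.is_compiling():
        logger.warning_once("Placeholder warning message about computing logits for every position.")

maybe_warn(None)  # warns once when running eagerly
maybe_warn(1)     # no warning
```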
@ArthurZucker is this planned for review this week? I'm pretty eager to consume this PR.
Yes! Reviewing asap!
Looking forward to testing this out, gemma2 uses a lot of memory otherwise and is a top model.
Hey @Cyrilvallez, thanks for your work. Just checking in regarding this PR. Do you have a plan to finish it up some time soon? I'm very excited for it to land!
Hi @ringohoffman, don't worry, I am not forgetting about this 🙂 I'm currently on vacation, so I will try to wrap it up quickly at the end of August when I come back if I have time. Worst case scenario, it will be ready mid-September.
In the meantime, you can install transformers from my fork if you want to already benefit from it (pip install git+https://github.com/Cyrilvallez/transformers@logits-dtype). Or even better, you can clone my fork and rebase it on transformers/main to get all the new stuff + this PR.
Does this PR actually fix Gemma2 or just Gemma?
Gemma2 was not released yet when I started this, but don't worry, I will add it as well, it's on the roadmap 🤗
@ArthurZucker I added support for Gemma2 as well as tests, ready for last review 🤗 The red CIs are not related to the PR.
Can you rebase to fix the CI? 🤗
My bad, there was actually a check-copy inconsistency that was my doing. Fixed it!
The last red CI is the libtorch_cuda.so: cannot open shared object file: No such file or directory issue that was also present on recent PRs.
@Cyrilvallez the CI images built after this PR gets merged should fix the CI issue we're seeing 🤗
Indeed it worked, @ArthurZucker everything is now green!
@Cyrilvallez rebasing (again) will get rid of the CI error
(bear with us, we're chasing these CI blockers 😅)
No worries! All good on the CIs and ready to be merged 🤗
Congrats, @Cyrilvallez!
When is this planned to be released? @ArthurZucker @gante
@ringohoffman our rule of thumb is to release every month, so it should be in ~2 weeks 🤗