
Reducing memory usage: removing useless logits computation in generate()

Open Cyrilvallez opened this issue 1 year ago • 20 comments

What does this PR do?

This is the PR related to the discussion in #30860. I followed what has been done in Jamba and added support for the num_logits_to_keep argument in forward(). However, even if this argument is None, the logits are only upcast to float if labels are passed (in order to accurately compute the loss); otherwise, the upcasting only happens in generate().
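In rough terms, the pattern looks like this (a minimal sketch with illustrative names and a hypothetical helper, not the exact modeling code of Llama/Mistral):

```python
import torch.nn as nn

def lm_head_forward(hidden_states, lm_head, labels=None, num_logits_to_keep=0):
    # Hypothetical helper for illustration only.
    if num_logits_to_keep > 0:
        # During generation only the last position(s) are needed, so avoid
        # materializing the full (batch, seq_len, vocab_size) logits tensor.
        hidden_states = hidden_states[:, -num_logits_to_keep:, :]
    logits = lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast only when the loss is actually computed, so training accuracy is unchanged.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
    return logits, loss
```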

For now, I only modified Llama and Mistral, but if you agree on the changes I will add support for more models.

Benchmarks

Here I provide some benchmarks of the peak memory usage. For each input size, I generated 10 additional tokens. Of course, since for a few additional tokens the memory peak scales only with the first forward pass (at least when computing the whole logits matrix), and the first forward scales linearly with input size and batch size (with the new attention algorithms), the gain is actually constant across all input sizes and generation methods (except for contrastive search, which artificially increases the batch size after the first forward, so its memory usage is slightly different). However, I still provide results for all generation methods here for completeness.

Basically we get:

- Llama3 8B -> MIND-BLOWING 3.62x memory usage reduction factor (due to the large vocabulary)
- Llama2 7B -> 1.17x reduction factor
- Mistral 7B -> 1.32x reduction factor
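A rough back-of-the-envelope check of why the vocabulary size drives the gain (assumed numbers, not the actual benchmark configuration): the full float32 logits tensor has shape (batch, seq_len, vocab_size), while generation only needs the last position.

```python
# Illustrative numbers only (Llama 3-style vocabulary, 4k-token prompt).
batch, seq_len, vocab = 1, 4096, 128_256
bytes_fp32 = 4

full_logits = batch * seq_len * vocab * bytes_fp32  # logits for every position
last_only = batch * 1 * vocab * bytes_fp32          # logits for the last position only

print(f"full: {full_logits / 2**30:.2f} GiB, last position only: {last_only / 2**20:.2f} MiB")
# -> full: 1.96 GiB, last position only: 0.49 MiB
```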

Note that the memory reduction shown here is on top of whatever gains #30536 already provides when generating few additional tokens, as I am comparing memory against the main transformers branch after that PR was merged. It integrates very nicely with that PR: the previous one provided most of its benefits when generating more tokens, while this one provides gains when only a small number of new tokens is generated.

greedy.pdf, sample.pdf, beam sample.pdf, beam search.pdf, group beam search.pdf, contrastive search.pdf

Here is a link to the benchmark script: https://gist.github.com/Cyrilvallez/92f48e402aa2968c854a8128796f50c3
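The core of such a peak-memory measurement looks roughly like the following (a simplified sketch, not the exact gist linked above; the checkpoint name and prompt length are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("Hello " * 2048, return_tensors="pt").to("cuda")

# Reset the CUDA peak-memory counter, generate a few tokens, then read the peak back.
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=10, do_sample=False)

print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```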

Who can review?

@ArthurZucker @gante Let me know what you think about the proposed changes!

Cyrilvallez avatar Jun 06 '24 15:06 Cyrilvallez

cc @ArthurZucker @gante

amyeroberts avatar Jun 06 '24 16:06 amyeroberts

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

btw, a ratio of 3x lower peak memory consumption is šŸ”„ šŸ”„ šŸ”„

gante avatar Jun 18 '24 13:06 gante

I just added the change to more models and rebased to avoid conflicts with new commits in main! Most notably, for Cohere-based models I measured a memory gain ratio of 6.68 due to the very large 256k vocabulary size šŸš€šŸ”„

The last thing to take into account is your comment about the signature, @ArthurZucker, but I'm not sure I understood correctly what you wanted to do šŸ¤“

Cyrilvallez avatar Jun 21 '24 09:06 Cyrilvallez

Make sure to rebase as the state of the main branch was changed quite a bit!

ArthurZucker avatar Jul 12 '24 09:07 ArthurZucker

Will do! However, when playing with torch.compile, I noticed that adding a logger.warning_once() in the forward breaks the graph with the following error: Unsupported: call_method UserDefinedObjectVariable(Logger) warning_once [ConstantVariable()] {}. This is with the latest PyTorch version (2.3.1). So I will make sure to change that/make it compatible as well.

Cyrilvallez avatar Jul 12 '24 12:07 Cyrilvallez

DO NOT MERGE YET. Everything else is good, but I still need to sort out the logger.warning_once/compile issue.

Cyrilvallez avatar Jul 12 '24 13:07 Cyrilvallez

@ArthurZucker @gante everything is now ready. From my tests, it seems like compile does not support any print-like functionality at the moment, whether from print, logger, or warnings. I first wanted to add a logger.warning_once_compile_safe function, which I thought would simplify things and come in handy in the future as well, but couldn't because it would need to import torch in the logging module, which breaks things. So I just added a compile check everywhere.
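Such a guard could look roughly like this (a sketch only; the helper name is hypothetical, and torch.compiler.is_compiling() is assumed to be available in recent PyTorch releases, whereas the actual PR uses the library's own compile check):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def warn_if_not_compiling(message: str) -> None:
    # Hypothetical helper: skip the warning while tracing under torch.compile,
    # where logging calls currently break the graph.
    # torch.compiler.is_compiling() is assumed available in recent PyTorch versions.
    if not torch.compiler.is_compiling():
        logger.warning(message)
```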

Cyrilvallez avatar Jul 17 '24 12:07 Cyrilvallez

@ArthurZucker is this planned for review this week? I’m pretty eager to consume this PR.

ringohoffman avatar Jul 24 '24 17:07 ringohoffman

Yes! Reviewing asap!

ArthurZucker avatar Jul 26 '24 10:07 ArthurZucker

Looking forward to testing this out; Gemma2 uses a lot of memory otherwise and is a top model.

Oxi84 avatar Aug 01 '24 13:08 Oxi84

Hey @Cyrilvallez, thanks for your work. Just checking in regarding this PR. Do you have a plan to finish it up some time soon? I'm very excited for it to land!

ringohoffman avatar Aug 07 '24 00:08 ringohoffman

Hi @ringohoffman, don't worry, I am not forgetting about this šŸ˜‰ I'm currently on vacation, so I will try to wrap it up quickly at the end of August when I come back, if I have time. Worst case scenario, it will be ready mid-September.

In the meantime, you can install transformers from my fork if you want to already benefit from it (pip install git+https://github.com/Cyrilvallez/transformers@logits-dtype). Or even better, you can clone my fork and rebase it on transformers/main to get all the new stuff + this PR.

Cyrilvallez avatar Aug 08 '24 06:08 Cyrilvallez

Does this PR actually fix Gemma2 or just Gemma?

Boubou78000 avatar Aug 08 '24 07:08 Boubou78000

Gemma2 was not released yet when I started this, but don't worry, I will add it as well; it's on the roadmap šŸ¤—

Cyrilvallez avatar Aug 08 '24 08:08 Cyrilvallez

@ArthurZucker I added support for Gemma2 as well as tests, ready for a last review šŸ¤— The red CIs are not related to the PR.

Cyrilvallez avatar Aug 21 '24 10:08 Cyrilvallez

Can you rebase to fix the CI? šŸ¤—

ArthurZucker avatar Aug 21 '24 10:08 ArthurZucker

My bad, there was actually a check-copy inconsistency that was my doing. Fixed it! The last red CI is the libtorch_cuda.so: cannot open shared object file: No such file or directory issue that was also present on recent PRs.

Cyrilvallez avatar Aug 21 '24 14:08 Cyrilvallez

@Cyrilvallez the CI images built after this PR gets merged should fix the CI issue we're seeing šŸ¤—

gante avatar Aug 21 '24 17:08 gante

Indeed it worked, @ArthurZucker everything is now green!

Cyrilvallez avatar Aug 22 '24 06:08 Cyrilvallez

@Cyrilvallez rebasing (again) will get rid of the CI error

(bear with us, we're chasing these CI blockers šŸ™ )

gante avatar Aug 23 '24 08:08 gante

No worries! All good on the CIs and ready to be merged šŸ¤—

Cyrilvallez avatar Aug 23 '24 09:08 Cyrilvallez

Congrats, @Cyrilvallez!

When is this planned to be released? @ArthurZucker @gante

ringohoffman avatar Aug 23 '24 17:08 ringohoffman

@ringohoffman our rule of thumb is to release every month, so it should be in ~2 weeks šŸ¤—

gante avatar Aug 23 '24 17:08 gante