
`Chat` consumes more VRAM than `Generate`

Open · Andrei-Aksionov opened this issue 1 year ago · 7 comments

Bug description

Hi there 👋

While working on the integration of Gemma 2 (the 9B variant), I noticed that while the regular generate script barely fits into an L4 with 24 GB, the chat script throws an OOM error. It doesn't look like a Gemma 2 9B-specific problem.

I tried a couple of other models, in regular and quantized form, with a single prompt:

What is the distance between Earth and the Moon?

Machine: Lightning Studio with 1xL4.

| Model | Chat | Generate | Δ | Chat (bnb.nf4) | Generate (bnb.nf4) | Δ |
|---|---|---|---|---|---|---|
| Phi-3 | 9.29 | 7.78 | 1.51 | 4.33 | 2.83 | 1.50 |
| TinyLlama (chat) | 2.60 | 2.30 | 0.30 | 1.38 | 1.07 | 0.31 |
| Gemma 1 7b-it | 20.58 | 18.83 | 1.75 | 11.45 | 9.69 | 1.76 |
| Gemma 2 9b-it | OOM | 20.58 | – | 16.48 | 10.98 | 5.50 |

\* Memory is in GB.

Chat is essentially the generate script running in a loop, so it should not consume more memory, at least when a single prompt is provided. Since the difference in memory consumption between the regular and the quantized model stays roughly the same, I assume, without even looking at the code, that something is wrong with memory preallocation (the KV cache?).
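To make the KV-cache suspicion concrete, here is a rough back-of-the-envelope sketch. The cache holds one K and one V tensor per layer, so its size scales linearly with the sequence length it is allocated for; the config numbers below are made up for illustration and are not any model's actual values:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   bytes_per_elem=2, batch_size=1):
    """Approximate KV-cache size: one K and one V tensor for every layer."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch_size

# Illustrative (not real) config: 32 layers, 8 KV heads, head dim 128, fp16.
full = kv_cache_bytes(32, 8, 128, seq_len=8192)  # cache sized to block_size
tight = kv_cache_bytes(32, 8, 128, seq_len=512)  # cache sized to prompt + new tokens
print(f"full: {full / 2**30:.2f} GiB, tight: {tight / 2**30:.2f} GiB")
```

With these hypothetical numbers the full-length cache is 16x larger than the tight one, which is the kind of constant gap (independent of whether the weights are quantized) seen in the table above.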

What operating system are you using?

Linux

LitGPT Version

Version: 0.4.3.dev0

Andrei-Aksionov avatar Jul 07 '24 12:07 Andrei-Aksionov

Thanks for reporting. Yes, this is weird. My first thought would also be that it's something with the KV cache. I assume the maximum new tokens length is the same, right?

rasbt avatar Jul 07 '24 12:07 rasbt

It looks like in the generate script we specify how many tokens to generate, and the value is calculated as the length of the prompt + max_new_tokens: https://github.com/Lightning-AI/litgpt/blob/3a4526ef9f1f107e78b8ab9dc537a144beb2d680/litgpt/generate/base.py#L261

In the chat script we don't do that and simply use the model's max sequence length: https://github.com/Lightning-AI/litgpt/blob/3a4526ef9f1f107e78b8ab9dc537a144beb2d680/litgpt/chat/base.py#L126-L128

Still, that alone shouldn't affect VRAM consumption, since the output has the same length in both cases, meaning generation is stopped by the EOS token rather than by this limit.

I'll investigate further after I finish with Gemma 2.

Andrei-Aksionov avatar Jul 07 '24 12:07 Andrei-Aksionov

Good observation, this could be related. We can debug this by first making these two consistent.

rasbt avatar Jul 07 '24 13:07 rasbt

It could be that in `generate/base.py` we set

```python
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
```

and we don't do this in chat!

aniketmaurya avatar Jul 13 '24 21:07 aniketmaurya

Yes, @aniketmaurya, that is exactly the reason. As I stated above, in the generate script the code overrides model.max_seq_length to len(prompt) + max_new_tokens, and this value is then used when preallocating space for the KV cache.

In the chat script this is not done, so the KV cache always has the size of max_seq_length, which is equal to block_size from the config.

I guess it was done this way because in chat mode you don't know the length of all prompts beforehand. But it's easily fixable: we can preallocate the KV cache for the first turn in the same fashion as in the generate script and then, if the prompt in the current turn is longer than all previous ones were, recreate the KV cache.
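A minimal sketch of that grow-only policy (class and method names are illustrative, not the actual litgpt API): track the largest cache length needed so far and rebuild the cache only when a turn requires more.

```python
class GrowOnlyKVCache:
    """Track the KV-cache length across chat turns; rebuild only when a
    turn needs more space than any previous one did."""

    def __init__(self):
        self.cache_len = 0

    def ensure(self, prompt_len, max_new_tokens):
        """Return True if the cache had to be (re)allocated for this turn."""
        required = prompt_len + max_new_tokens
        if required > self.cache_len:
            self.cache_len = required  # the real code would rebuild the cache here
            return True
        return False


cache = GrowOnlyKVCache()
cache.ensure(100, 50)  # first turn: allocates
cache.ensure(80, 50)   # shorter prompt: reuses the existing cache
cache.ensure(200, 50)  # longer prompt: rebuilds
```

Shorter follow-up turns reuse the existing allocation, so in the common case the chat loop pays the rebuild cost only once.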

I'll create a PR on Monday with the fix.

Andrei-Aksionov avatar Jul 14 '24 10:07 Andrei-Aksionov

> But it's easily fixable: we can preallocate the KV cache for the first turn in the same fashion as in the generate script and then, if the prompt in the current turn is longer than all previous ones were, recreate the KV cache.

Note that this will incur recompilations if compilation is enabled (which chat.py supports). You might want to enable that optimization only when compilation is not enabled.
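The compile-aware variant could be sketched like this (a hypothetical helper, not litgpt code): keep the static full-size cache when compilation is on, so tensor shapes never change, and size the cache tightly otherwise.

```python
def pick_cache_len(prompt_len, max_new_tokens, block_size, compile_enabled):
    """Choose the KV-cache length for a chat turn.

    With compilation enabled, keep the static block_size-sized cache so
    tensor shapes stay fixed and no recompilation is triggered; otherwise
    size the cache tightly to save memory (capped at block_size).
    """
    if compile_enabled:
        return block_size
    return min(prompt_len + max_new_tokens, block_size)
```

The `min` cap matters: a prompt near the model's context limit must not produce a cache larger than `block_size`.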

carmocca avatar Jul 15 '24 12:07 carmocca

Good point. I guess an additional message saying that it can lead to higher memory consumption due to a large KV cache would be helpful.

Andrei-Aksionov avatar Jul 15 '24 12:07 Andrei-Aksionov

This reminds me: at some point we discussed an option like `optimize="compute" | "memory"` for less advanced users, and this could be a good trade-off for such a setting.

rasbt avatar Jul 15 '24 14:07 rasbt