`Chat` consumes more VRAM than `Generate`
Bug description
Hi there 👋
While working on the integration of Gemma 2 (the 9B variant), I noticed that the regular generate script barely fits into an L4 with 24 GB, while the chat script throws an OOM error. It doesn't look like a Gemma 2 9B specific problem.
I tried a couple of other models, in regular and quantized form, with a single prompt:
What is the distance between Earth and the Moon?
Machine: Lightning Studio with 1xL4.
| Model | Chat | Generate | $Δ$ | Chat (bnb.nf4) | Generate (bnb.nf4) | $Δ$ |
|---|---|---|---|---|---|---|
| Phi-3 | 9.29 | 7.78 | 1.51 | 4.33 | 2.83 | 1.5 |
| TinyLlama (chat) | 2.60 | 2.30 | 0.3 | 1.38 | 1.07 | 0.31 |
| Gemma 1 7b-it | 20.58 | 18.83 | 1.75 | 11.45 | 9.69 | 1.76 |
| Gemma 2 9b-it | OOM | 20.58 | - | 16.48 | 10.98 | 5.5 |
* Memory is in GB.
Chat is essentially the generate script running in a loop, so it should not consume more memory, at least when a single prompt is provided. Since the gap between chat and generate is roughly the same for the regular and the quantized variant of each model, I assume, without even looking at the code, that something is off with memory preallocation (the KV cache?).
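For a rough sanity check (this is a generic transformer KV-cache estimate, not litgpt's exact cache layout, and the config numbers below are illustrative rather than taken from any model in the table): the preallocated cache grows linearly with the sequence length it is sized for, and its dtype is unaffected by weight-only quantization such as bnb.nf4, which would explain why the gap shows up identically in both columns.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Rough size of a preallocated KV cache: K and V tensors for every layer."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# Illustrative config values, not any specific model from the table above.
full_context = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
short_prompt = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=256)
print(f"cache sized for the full context:    {full_context / 1e9:.2f} GB")
print(f"cache sized for prompt + new tokens: {short_prompt / 1e9:.2f} GB")
```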
What operating system are you using?
Linux
LitGPT Version
Version: 0.4.3.dev0
Thanks for reporting. Yes, this is weird. My first thought would also be that it's something with the KV cache. I assume the maximum number of new tokens is the same, right?
It looks like in the generate script we specify how many tokens to return, and that value is calculated as the length of the prompt + `max_new_tokens`:
https://github.com/Lightning-AI/litgpt/blob/3a4526ef9f1f107e78b8ab9dc537a144beb2d680/litgpt/generate/base.py#L261
In the chat script we don't do that and simply use the model's max sequence length: https://github.com/Lightning-AI/litgpt/blob/3a4526ef9f1f107e78b8ab9dc537a144beb2d680/litgpt/chat/base.py#L126-L128
Still, that alone shouldn't affect VRAM consumption, since the output has the same length in both cases, meaning that generation is stopped by the EOS token rather than by this limit.
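For reference, here is a paraphrase of the difference between the two linked call sites (my reading of them, not verbatim repo code), with dummy numbers to make it concrete:

```python
# Dummy values, only to illustrate the two code paths described above.
prompt_length = 16      # tokens in the encoded prompt
max_new_tokens = 50     # generation limit passed to the script
block_size = 8192       # the model's configured max sequence length

generate_max_returned_tokens = prompt_length + max_new_tokens  # generate/base.py
chat_max_returned_tokens = block_size                          # chat/base.py

print(generate_max_returned_tokens, "vs", chat_max_returned_tokens)  # 66 vs 8192
```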
I'll investigate further after I finish with the Gemma 2 integration.
Good observation, this could be related. We can debug this by first making these two consistent.
Could be that in `generate/base.py` we set

```python
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
```

and we don't do this in Chat!
Yes, @aniketmaurya, that is exactly the reason.
As I stated above, in the generate script the code overrides `model.max_seq_length` to be `len(prompt) + max_new_tokens`, and this value is then used when preallocating space for the KV cache.
In the chat script this is not done, so the KV cache always has the size of `max_seq_length`, which is equal to `block_size` from the config.
I guess it was done this way because in chat mode you don't know the length of all the prompts beforehand. But it's easily fixable: we can preallocate the KV cache for the first turn in the same fashion as in the generate script and then, if the current turn's prompt is longer than all the previous ones, recreate the KV cache.
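A minimal sketch of that idea (a hypothetical helper; it assumes litgpt's `GPT.set_kv_cache(batch_size=...)` method together with the `model.max_seq_length` override quoted above):

```python
def maybe_grow_kv_cache(model, prompt_length: int, max_new_tokens: int,
                        preallocated: int) -> int:
    """Recreate the KV cache only when this turn needs more room than any
    previous turn did. Hypothetical helper, called once per chat turn before
    generation; returns the (possibly updated) preallocated size."""
    max_returned_tokens = prompt_length + max_new_tokens
    if max_returned_tokens > preallocated:
        model.max_seq_length = max_returned_tokens
        model.set_kv_cache(batch_size=1)  # reallocate at the new size
        return max_returned_tokens
    return preallocated
```

With `preallocated` starting at 0, the first turn sizes the cache exactly like the generate script does.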
I'll create a PR on Monday with the fix.
> But it's easily fixable: we can preallocate the KV cache for the first turn in the same fashion as in the generate script and then, if the current turn's prompt is longer than all the previous ones, recreate the KV cache.
Note that this will incur recompilations if compilation is enabled (which `chat.py` supports). You might want to enable that optimization only when compilation is not enabled.
Good point. I guess an additional message saying that this can lead to higher memory consumption due to a large KV cache would be helpful.
This reminds me: at some point we discussed an option like `optimize="compute" | "memory"` for less advanced users, and this could be a good trade-off for such a setting.
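Purely as a sketch of that idea (`optimize` is not an existing litgpt argument, and the function name here is hypothetical):

```python
def setup_chat_kv_cache(model, prompt_length: int, max_new_tokens: int,
                        optimize: str = "memory", compiled: bool = False) -> None:
    """Hypothetical switch between the two behaviours discussed above."""
    if optimize == "memory" and not compiled:
        # size the cache for this turn only; it may have to be recreated later
        model.max_seq_length = prompt_length + max_new_tokens
    # optimize == "compute" (or a compiled model): keep the full-context cache,
    # so it never has to be rebuilt and no recompilation is triggered
    model.set_kv_cache(batch_size=1)
```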