
Support for KV caching and batched inference

Open mseeger opened this issue 10 months ago • 25 comments

Adds abstraction for key-value caches, implements batched inference.

I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.

The abstraction contains methods not used by these baselines, but they are required to implement more advanced KV caches such as Heavy Hitter Oracle (H2O).

I have implemented some of these, but I may not be allowed to contribute them here (working for a company). I'll see what I can do.

mseeger avatar Feb 06 '25 09:02 mseeger

Hey, great work @mseeger .

Can we decouple things a lot, though?

Some initial thoughts:

  • I would prefer if we kept the KVCache initialization as in the current version (i.e. that you initialize the model, then potentially adjust the max seq len and then initialize the KVCache; see the sketch after this list) in this PR. Adding this to the init parameters seems orthogonal to the other changes.
  • We do have batched generation today. Can we please split changes to batched generation from the KVCache improvements? We probably don't want to do batching via lists of tensors. I'm currently looking at passing in "packed" input/input_pos sequences, but that is separate from these changes. Changing the existing tests should be a bit of a red flag, as changing the API will break existing users (we can do this if we need to, but TBH I am not convinced this is the case).
  • In general, can we be very conservative with adding arguments? For optional arguments, we should look into making them keyword-only unless there is a good reason not to.
  • We want to keep control flow simple. self._default_kv_cache = False is not a good idea.
  • I'm not sure I understand the both_in_parallel. Maybe the right time to add it and the associated refactors is when they are used?
  • I'm generally a bit wary of the amount of data structures and cases that are being passed around here; those add a lot of complexity. To my mind, this likely means that the right abstraction has not yet been found. Maybe integrating KVCache and SDPA more could be a thing, but I am not sure.
  • In general, we do not want to do the cache setup during the forward. Please keep the initialization separate. I think we are rather seeing movement towards less of a distinction between pre-fill and next token, so this seems a bit in the wrong direction.
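
To make the first point concrete, roughly the flow I have in mind (a sketch from memory; exact import paths and argument names may differ slightly between versions):

import torch
from litgpt.config import Config
from litgpt.model import GPT

# current flow (sketch): build the model, optionally shrink the max sequence
# length, then create the KV caches explicitly before running inference
model = GPT(Config.from_name("pythia-14m"))
model.max_seq_length = 256
model.set_kv_cache(batch_size=1)

x = torch.randint(0, model.config.vocab_size, (1, 8))
input_pos = torch.arange(8)
with torch.no_grad():
    logits = model(x, input_pos)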

Again, super good stuff in the PR! I think there are a few things to split out and consider individually and then maybe we can have a video call about the core KVCache things, wdyt?

Thanks for the initiative for better KV caching!

t-vi avatar Feb 06 '25 10:02 t-vi

Hello, sure we can have a call, I am in the central Europe (Germany) time zone.

mseeger avatar Feb 06 '25 15:02 mseeger

My impression was that batched generation is not really there. But if it is, I don't ask to change it.

One thing is important, though. KV caches really work by filling positions sequentially. So if you have filled positions 0:(T-1), you need to continue with T, or with T:(T+k). The current API of just passing arbitrary position indexes is really not going to work.
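
To make this concrete, here is a minimal sketch (illustrative only, not the class from this PR) of a dense cache that keeps a write pointer and only accepts contiguous continuations:

import torch

class SequentialKVCache:
    """Toy dense cache: stores everything, can only be filled left to right."""

    def __init__(self, batch_size, n_heads, max_len, head_dim, device=None, dtype=None):
        shape = (batch_size, n_heads, max_len, head_dim)
        self.k = torch.zeros(shape, device=device, dtype=dtype)
        self.v = torch.zeros(shape, device=device, dtype=dtype)
        self.next_pos = 0  # positions 0 .. next_pos - 1 are filled

    def __call__(self, k_new, v_new, input_pos):
        T_new = k_new.shape[-2]
        # the cache can only continue where it left off: T, or T:(T+k)
        if input_pos != self.next_pos:
            raise ValueError(f"cache is filled up to {self.next_pos}, got input_pos={input_pos}")
        self.k[:, :, input_pos:input_pos + T_new] = k_new
        self.v[:, :, input_pos:input_pos + T_new] = v_new
        self.next_pos += T_new
        return self.k[:, :, :self.next_pos], self.v[:, :, :self.next_pos]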

mseeger avatar Feb 06 '25 15:02 mseeger

Also, the implementation right now allows you to send in KV cache objects from the start. If you do not do that, it will create them by default. This is done by set_kv_cache. If you also do not do that, it is done in the first forward with for_prefill=True.

Note that prefill here means that I can do a single pass, and the cache can take it all, without having to evict anything. It does not mean that this will encode even the shortest prompt in the batch. If prompts are longer than the max prefill length, you need to do it sequentially in chunks.

Maybe there is an easier way, we can discuss.
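
A rough sketch of the chunked processing I mean (the name max_prefill_length and the input_pos convention are just for illustration, not the final API):

import torch

def encode_long_prompt(model, prompt_ids, max_prefill_length):
    # prompt_ids: (1, T); first pass prefills as much as the cache can take
    # without evicting anything
    T = prompt_ids.shape[-1]
    end = min(T, max_prefill_length)
    logits = model(prompt_ids[..., :end], input_pos=torch.arange(end))
    # remaining prompt tokens are pushed through sequentially in chunks;
    # the cache may have to evict entries while this happens
    while end < T:
        nxt = min(T, end + max_prefill_length)
        logits = model(prompt_ids[..., end:nxt], input_pos=torch.arange(end, nxt))
        end = nxt
    return logits  # logits of the last chunk, used to sample the first new token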

mseeger avatar Feb 06 '25 15:02 mseeger

It is annoying that I cannot show you the KV cache code I have. But in a talk, I could explain why a few things are the way they are. Of course, I am not aware of the other constraints you guys have.

mseeger avatar Feb 06 '25 15:02 mseeger

You may ask why KVCache.prefill exists. The main reason is that you want to use SDPA whenever you can, but SDPA cannot return the attention weights, which some KV cache algorithms (H2O) need in order to decide what to evict next.

We can arrange things so that the very first call to the model, with input_pos=0, does this. So, instead of

model(x, for_prefill=True)

you'd call

model(x, input_pos=0)

This I could do. That would indeed be a little simpler.
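
To illustrate the SDPA point (a sketch, not the code in this PR): the prefill path can use the fused kernel, while a cache like H2O needs the softmax weights materialized:

import math
import torch
import torch.nn.functional as F

def attend_prefill(q, k, v):
    # fused SDPA: fast, but never materializes the attention weights
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def attend_with_weights(q, k, v):
    # explicit computation: slower, but returns the weights a cache like H2O
    # needs to score which entries to evict next (masking omitted for brevity)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = scores.softmax(dim=-1)
    return weights @ v, weights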

mseeger avatar Feb 06 '25 15:02 mseeger

@t-vi Let me know what the next steps here should be. If I understand correctly, I could:

  • Get rid of for_prefill parameter, and use input_pos=0 instead
  • Don't create a default KV cache in forward and rather fail the call if input_pos is used, so that the user needs to call set_kv_cache
  • You don't seem to approve of passing the KV caches at construction (if user does not want to use default ones). Would you rather use set_kv_cache for that?

mseeger avatar Feb 06 '25 19:02 mseeger

Hi, so I think we should try to break things down.

We could either start with the core caching itself and try to see how to integrate it with minimal changes, or see what is the deal with batching and prefill first. I sent a message to your gmail address to find a good time to discuss.

t-vi avatar Feb 09 '25 09:02 t-vi

Hello @t-vi , let me try to break things down. Changes are these:

  1. KVCache and its implementations. This replaces the default cache, which just stores everything. No behavior changes.
  2. Caches for each layer can be passed when the model is created. As before, there is set_kvcache, which creates the default caches. If nothing is done at all, default caches are created when first needed. This is a change: before, it would raise an exception.
  3. Refactoring of the generation code: this works for batch generation now, and single-sequence generation is a special case. Internally, it properly supports large prompts by splitting generation into a prefill (as large as the caches allow) and then sequential blocks of the desired length.

mseeger avatar Feb 21 '25 07:02 mseeger

If I understand you correctly, you complain about 2., especially the automatic creation of the default cache when nothing is done, and the change of __init__ of GPT. This I can work on. I could do the following:

  • Allow passing KV caches per layer in set_kvcache (or have another method?)
  • Create default KV caches by calling set_kvcache. If this is not done, calling forward for inference fails, so no cache is created automatically

Would that be what you prefer?
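
A hypothetical sketch of the first bullet; the helper _build_default_kv_cache and the exact signature are made up here for illustration:

from typing import Optional, Sequence

def set_kv_cache(self, batch_size: int,
                 kv_caches: Optional[Sequence[object]] = None) -> None:
    blocks = self.transformer.h  # one attention block per layer (litgpt layout)
    if kv_caches is None:
        # default: dense caches that simply store all keys and values
        kv_caches = [self._build_default_kv_cache(batch_size, block) for block in blocks]
    if len(kv_caches) != len(blocks):
        raise ValueError("need exactly one KV cache per transformer block")
    for block, cache in zip(blocks, kv_caches):
        block.attn.kv_cache = cache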

mseeger avatar Feb 21 '25 07:02 mseeger

As for 1. and 3., in the end they go together, but I can try to split them into two. I'd first do 1., keeping the generation code in place, which would however not work for batches and not properly support the sequential processing of prompts.

First doing 3. is not really sensible, because it requires things from 1.

What do you think?

mseeger avatar Feb 21 '25 07:02 mseeger

Note that with DeepSeek (I am involved in trying to bring this to Hugging Face), there is a lot of momentum now not to ignore KV caching in the future. They even released a paper on how they can train with large contexts.

mseeger avatar Feb 21 '25 07:02 mseeger

OK, I did 2) as far as I understand it. I'll work on 1) once I find time.

mseeger avatar Feb 24 '25 15:02 mseeger

No idea why all these tests are failing. Tests work for me locally.

mseeger avatar Feb 24 '25 16:02 mseeger

@t-vi Maybe I can change your mind about first keeping the current generation code in place, and only contribute the KV cache support?

This is quite a bit of extra work for me, and my new code has a number of improvements. In particular, the current code does not really do batch generation; it is marked with several TODOs and is not used.

If we could have a chat, I'd appreciate that.

mseeger avatar Feb 26 '25 21:02 mseeger

Your CI system seems to be broken still.

mseeger avatar Feb 27 '25 10:02 mseeger

Out of curiosity: Why do you object to batch prompts being a list of tensors? In the end, they can have wildly different lengths, and there is not much you can do against that (sure, if you get lots of requests, you can maybe cluster them, but doing this too much delays requests, so increases latency).

Also, you really don't want to push PAD tokens into models just because a prompt in a batch happens to be shorter than others. The model, not being trained on this, would certainly get confused. And since you need to start token-by-token forward for generation, you really gain nothing by padding prompts.

I always thought of this as some kind of TensorFlow artefact from when all tensors had to be allocated up front, etc. But I thought we had overcome this with PyTorch.

mseeger avatar Feb 27 '25 11:02 mseeger

Hey, sorry, I am totally swamped, still want to have a video call to chat.

Out of curiosity: Why do you object to batch prompts being a list of tensors? In the end, they can have wildly different lengths, and there is not much you can do against that (sure, if you get lots of requests, you can maybe cluster them, but doing this too much delays requests, so increases latency).

Because lists are a lot less nice to work with in various setups passing to kernels, cudagraphs etc.

For somewhat homogeneous sequence lengths, padding works fine. We are using it in production, so I'm doubting claims that it does not work. It does have limitations with inhomogeneous sequence lengths, which we want to support.

But the proper way to support this is packed sequences, i.e. pass in flat input_tokens, input_pos (i.e. 1d shape, no batch index) and then batch_seq_lens of shape batch_size. batch_seq_lens gives the length of each batch item (which might even be 0).

This is hugely more flexible. It needs FlexAttention or something similar (https://pytorch.org/docs/stable/nn.attention.flex_attention.html) to make it work efficiently in stock PyTorch.
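
A minimal sketch of that packed representation (the dense mask at the end is only to show the semantics; FlexAttention or similar is what makes it efficient at scale):

import torch

prompts = [torch.tensor([11, 12, 13, 14]),   # length 4
           torch.tensor([21, 22]),           # length 2
           torch.tensor([31, 32, 33])]       # length 3

lens = [p.numel() for p in prompts]
batch_seq_lens = torch.tensor(lens)                       # shape (batch_size,)
input_tokens = torch.cat(prompts)                         # flat, no batch dimension
input_pos = torch.cat([torch.arange(n) for n in lens])    # position within each sequence

# document ids mark which flat token belongs to which sequence; attention is
# then restricted to "same document and causal"
doc_ids = torch.repeat_interleave(torch.arange(len(prompts)), batch_seq_lens)
q_idx = torch.arange(input_tokens.numel()).unsqueeze(1)
kv_idx = torch.arange(input_tokens.numel()).unsqueeze(0)
attn_mask = (doc_ids.unsqueeze(1) == doc_ids.unsqueeze(0)) & (q_idx >= kv_idx)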

t-vi avatar Feb 27 '25 11:02 t-vi

Let me know when a good time is. I am in the European time zone.

mseeger avatar Feb 27 '25 12:02 mseeger

After our call, I think I understand better what you mean. Something like an abstraction in multi-head attention, where the inputs are the keys, values, and queries for the current input chunk, all the same size, and then this is bundled:

  • Take in keys and values and replace them with the KV-cached ones, so keys and values are now larger
  • Do the SDPA computation
  • Feed attention weights back to KV cache if needed
  • Return MHA outputs before final linear mapping

This makes a lot of sense, and is quite elegant.
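
A sketch of that bundling (names like needs_attn_weights and update_scores are placeholders, not an agreed interface):

import math
import torch
import torch.nn.functional as F

def cached_attention(kv_cache, q, k, v, input_pos):
    # 1) the cache extends/replaces k, v with the stored history (may evict)
    k_all, v_all = kv_cache(k, v, input_pos)
    if kv_cache.needs_attn_weights:              # e.g. H2O-style eviction scoring
        scores = q @ k_all.transpose(-2, -1) / math.sqrt(q.shape[-1])
        weights = scores.softmax(dim=-1)         # (masking omitted for brevity)
        kv_cache.update_scores(weights)          # 3) feed weights back to the cache
        y = weights @ v_all                      # 2) the attention computation
    else:
        y = F.scaled_dot_product_attention(q, k_all, v_all)  # fused SDPA path
    return y  # 4) MHA output before the final linear projection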

mseeger avatar Mar 05 '25 15:03 mseeger

Your CI system seems to be broken still.

@mseeger shall be fixed now, thank you for your patience :)

Borda avatar Mar 18 '25 09:03 Borda

As discussed with @t-vi , I'll refactor this as stated in the comment above. Makes total sense

mseeger avatar Mar 19 '25 10:03 mseeger

@t-vi , is this what you had in mind? KVCache is now very simple, and so the code in CausalSelfAttention is not polluted by details of KV caching. I added DefaultKVCache which most KV caches will use.

mseeger avatar Mar 27 '25 17:03 mseeger

I'd be OK with taking out my batched inference part here, but this means there won't be any. Do you have plans to add the batched inference code you talked about any time soon?

mseeger avatar Mar 27 '25 17:03 mseeger

OK, I've taken out the batched inference code. Still working on fixing the tests (and need to refactor speculative decoding), but this is essentially it.

mseeger avatar Apr 04 '25 19:04 mseeger

@t-vi , it would be great to get some feedback on this one, before I spend time on fixing tests for code which I need to change afterwards anyway.

mseeger avatar Apr 28 '25 13:04 mseeger

BTW: Even vLLM does not have consistent support for different KV caching techniques and strategies. They just offer some blunt stuff like tensor parallelism and quantization.

This could be a real differentiating feature of LitGPT.

If you know another open source library that indexes on KV caching, and which you'd like to integrate with instead, please let me know.

mseeger avatar Apr 29 '25 06:04 mseeger

@t-vi , @Borda : Any chance there will be some progress here?

I recognize this is a big PR. On the other hand, decent support for selective (sparse) KV caching could be a real differentiator for LitGPT, in that none of the other OS libraries spend effort on it, not even vLLM.

I have made quite a bit of progress, also on fine-tuning with long contexts. I am trying to get approval to open source this. In my team, we are starting to use LitGPT quite seriously (HF being just too messy).

mseeger avatar May 12 '25 16:05 mseeger

One thing still missing is good support for batch inference without excessive padding.

mseeger avatar May 12 '25 16:05 mseeger

@t-vi , @Borda: Any sign of life on this PR? A few things:

  • This is not meant to be a single PR, but I am happy to chop it into smaller pieces. At the moment, I am simply using this PR to get my work done on the extra repo I am writing
  • I made substantial progress: I now have a method to compute gradients (!) with KV caching inside. To the best of my knowledge, this does not exist so far in any OS repo. This could be a real strong point for LitGPT
  • I am getting closer to being allowed to open source my code (which is a library on top of LitGPT), but this needs some integration into your code, which is what I am proposing here
  • Finally, I'd argue my code also cleans up a few things

mseeger avatar May 27 '25 07:05 mseeger