KV cache optimization with paged attention
Feature request
Paged attention has already been adopted by many serving engines, e.g., vLLM and TensorRT-LLM.
Motivation
The KV cache is used to reduce computation in the decoder layers, but it also brings memory overhead. For example, when we use beam search, the kv_cache has to be reordered according to the latest beam indices, and the current key/value states also have to be concatenated with the kv_cache in the attention layer to recover the entire context for scaled dot-product attention. When the sequence is very long, this memory overhead becomes the performance bottleneck.
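For reference, a minimal PyTorch sketch of the two per-step operations described above (the tensor names and shapes are illustrative only, not a transformers API):

```python
import torch

batch_beams, num_heads, seq_len, head_dim = 4, 8, 1024, 64
key_cache = torch.randn(batch_beams, num_heads, seq_len, head_dim)
value_cache = torch.randn(batch_beams, num_heads, seq_len, head_dim)

# 1) Beam search: reorder the whole cache along the batch axis with the latest beam indices.
beam_idx = torch.tensor([0, 0, 2, 1])  # which beam each new hypothesis came from
key_cache = key_cache.index_select(0, beam_idx)
value_cache = value_cache.index_select(0, beam_idx)

# 2) Decoding step: concatenate the current token's key/value, which copies the whole cache again.
new_key = torch.randn(batch_beams, num_heads, 1, head_dim)
new_value = torch.randn(batch_beams, num_heads, 1, head_dim)
key_cache = torch.cat([key_cache, new_key], dim=2)
value_cache = torch.cat([value_cache, new_value], dim=2)

# Both steps materialize O(batch * heads * seq_len * head_dim) new memory every token,
# which is the overhead paged attention avoids by indexing fixed-size blocks instead.
```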
Your contribution
No PR yet
cc @gante (I think this is closest to your work - sorry if wrong!)
@jgong5
Hi @liangan1 👋
We are close to introducing a new cache abstraction (https://github.com/huggingface/transformers/pull/26681). I believe that, after this PR is merged, paged attention could be added directly on top of it :)
Would you be interested in adding it to transformers?
Sure. We would be pleased to contribute more kv_cache-related optimizations.
Awesome, I will let you know when the cache abstraction is ready!
Thanks.
@liangan1 the cache abstraction will be merged today, so you can start working on top of it. Happy to provide pointers and suggestions! 🙌
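For anyone picking this up, here is a rough, hypothetical sketch of what a paged cache built on top of the new cache abstraction could look like. The class name `PagedCache`, the block layout, and the default block size are assumptions for illustration only; the only transformers API assumed is `Cache.update(key_states, value_states, layer_idx, cache_kwargs)` as introduced in #26681, and the base class may have evolved in later releases:

```python
from typing import Dict, List, Optional
import torch
from transformers.cache_utils import Cache


class PagedCache(Cache):  # hypothetical class, not part of transformers
    """Stores K/V in fixed-size blocks so cache memory grows in block-size increments
    instead of being reallocated and copied at every decoding step."""

    def __init__(self, block_size: int = 16):
        super().__init__()
        self.block_size = block_size
        self.key_blocks: Dict[int, List[torch.Tensor]] = {}    # layer_idx -> list of blocks
        self.value_blocks: Dict[int, List[torch.Tensor]] = {}
        self.seen_tokens: Dict[int, int] = {}                  # layer_idx -> cached length

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # key_states / value_states: (batch, num_heads, new_tokens, head_dim)
        k_blocks = self.key_blocks.setdefault(layer_idx, [])
        v_blocks = self.value_blocks.setdefault(layer_idx, [])
        seen = self.seen_tokens.get(layer_idx, 0)
        batch, heads, new_tokens, head_dim = key_states.shape

        for t in range(new_tokens):
            slot = (seen + t) % self.block_size
            if slot == 0:  # last block is full (or no block yet): allocate a fresh one
                block_shape = (batch, heads, self.block_size, head_dim)
                k_blocks.append(key_states.new_zeros(block_shape))
                v_blocks.append(value_states.new_zeros(block_shape))
            k_blocks[-1][:, :, slot] = key_states[:, :, t]
            v_blocks[-1][:, :, slot] = value_states[:, :, t]

        self.seen_tokens[layer_idx] = seen + new_tokens
        # Return the logically contiguous view the attention layer expects; a real paged
        # attention kernel would instead consume the block table directly and skip this copy.
        keys = torch.cat(k_blocks, dim=2)[:, :, : self.seen_tokens[layer_idx]]
        values = torch.cat(v_blocks, dim=2)[:, :, : self.seen_tokens[layer_idx]]
        return keys, values

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return self.seen_tokens.get(layer_idx, 0)

    def get_max_length(self) -> Optional[int]:
        return None  # no fixed maximum; blocks are allocated on demand
```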
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
As of the latest release of FlashAttention, v2.5, a paged KV cache is now supported: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#25-paged-kv-cache. Having this implemented in transformers would be pretty awesome, especially when it can stack with a quantized KV cache, allowing for more than 100,000 tokens of context on consumer GPUs; with 64 GB of shared memory, something like 500,000 tokens of context on a 7B 4-bit model.
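For illustration, a hedged sketch of how the paged KV cache entry point in flash-attn >= 2.5 can be called (requires a CUDA GPU and fp16/bf16 tensors; shapes follow the flash-attn README, and the exact argument set should be double-checked against the installed version):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
page_block_size = 256                       # the 2.5 README requires a multiple of 256
num_blocks, max_blocks_per_seq = 16, 4

device, dtype = "cuda", torch.float16
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one new query token
k_cache = torch.randn(num_blocks, page_block_size, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(num_blocks, page_block_size, nheads, headdim, device=device, dtype=dtype)

# Each row maps a sequence's logical blocks to physical blocks inside k_cache/v_cache.
block_table = torch.arange(
    batch * max_blocks_per_seq, device=device, dtype=torch.int32
).reshape(batch, max_blocks_per_seq)
cache_seqlens = torch.full((batch,), 300, device=device, dtype=torch.int32)  # cached tokens per sequence

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)
```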
@gante hi Joao, I am wondering if there are plans to implement better scheduling in GenerationMixin.generate for the case where a large input batch is passed (larger than what can be processed at a time), and for the case where some sequences finish earlier than others?
Or, in the near future, should we expect GenerationMixin.generate to keep processing dummy tokens for finished sequences in a batch, and not attempt to maximize the batch size when there are many requests?