Paged Attention Based on the Latest Cache Design
What does this PR do?
Based on the latest cache design in #26681, this PR implements the Paged Attention KV cache proposed by this paper.
Fixes #27303
Who can review?
@tomaarsen @gante @patrickvonplaten @jgong5 @jianan-gu
This should wait until the cache refactor is finished, cc @gante: https://github.com/huggingface/transformers/pull/27407
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Please help to reopen this PR; we will refine it based on the latest transformers.
@liangan1 reopened :D
Thanks. I will refresh this PR ASAP.
@gante I have rebased this PR and validated the functionality with the Llama model for both greedy and beam search; please help review. The `update` function should be a good starting point for understanding it.
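For orientation, here is a rough sketch of the idea behind the `update` path (simplified, with illustrative names, shapes, and block handling rather than the exact code in this PR):

```python
import torch

class PagedKVCacheSketch:
    """Illustrative only: K/V states live in fixed-size physical blocks instead of one contiguous tensor."""

    def __init__(self, num_blocks, block_size, num_heads, head_dim, dtype=torch.float32):
        self.block_size = block_size
        # Physical storage: [num_blocks, block_size, num_heads, head_dim]
        self.key_blocks = torch.zeros(num_blocks, block_size, num_heads, head_dim, dtype=dtype)
        self.value_blocks = torch.zeros_like(self.key_blocks)
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids in logical order
        self.seq_lens = {}      # seq_id -> number of cached tokens

    def _allocate(self, seq_id, num_new_tokens):
        """Grab free physical blocks until the sequence has room for num_new_tokens more tokens."""
        table = self.block_tables.setdefault(seq_id, [])
        needed = self.seq_lens.get(seq_id, 0) + num_new_tokens
        while len(table) * self.block_size < needed:
            table.append(self.free_blocks.pop())

    def update(self, key_states, value_states, seq_id):
        """Scatter this step's tokens ([num_new, num_heads, head_dim]) into their physical blocks."""
        num_new = key_states.shape[0]
        self._allocate(seq_id, num_new)
        start = self.seq_lens.get(seq_id, 0)
        for i in range(num_new):
            block_id = self.block_tables[seq_id][(start + i) // self.block_size]
            offset = (start + i) % self.block_size
            self.key_blocks[block_id, offset] = key_states[i]
            self.value_blocks[block_id, offset] = value_states[i]
        self.seq_lens[seq_id] = start + num_new
```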
Thanks @gante I will refine it according to your comments.
Thank you for working on this, it seems like it's growing in the right direction 💪 Apologies for the delayed review, I went back to reread the paper.
In addition to the comments added in the code:
A - General comments:
- nit: whenever possible, try to prepend `_` to the name of methods you don't expect to use outside the class (e.g. `allocate` -> `_allocate`)
- nit: raise exceptions instead of using `assert`

B - What would I like to see in the next review:
- A test to ensure the outputs with the original cache don't change when we switch to this new cache
- Some benchmarks -- the point of Paged Attention is performance, so there is no point in merging the code if it ends up being slower and/or consuming more memory 🤗
Refined according to A - General comments, and added a UT to ensure the output does not change with the new cache.
@gante Sorry for the late reply due to holidays. The static cache is a good choice for greedy search: it uses a large buffer to store the past key/value states and removes the concat overhead. For beam search, however, there is also `reorder_cache` overhead on top of the concat overhead in the attention module; besides, the prompt can be shared across beams to reduce the first-token latency, especially for large batch sizes or long sequences. Prompt sharing has the following advantages (a rough sketch follows the list):
- Only 1/beam_size of the memory consumption compared to the native implementation.
- Better cache locality with prompt reuse.
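As a rough illustration of prompt sharing on top of the block-table sketch above (hypothetical helper, not the PR code): each beam's block table points at the same physical prompt blocks, so the prompt is stored only once.

```python
def fork_prompt_for_beams(cache, prompt_seq_id, beam_seq_ids):
    """Illustrative only: let every beam reuse the prompt's physical blocks instead of copying them.

    Only blocks appended after the fork are beam-private, so prompt memory is roughly
    1/beam_size of a naive per-beam copy. A real implementation also needs copy-on-write
    for the partially filled last prompt block before a beam writes new tokens into it.
    """
    prompt_blocks = cache.block_tables[prompt_seq_id]
    prompt_len = cache.seq_lens[prompt_seq_id]
    for beam_id in beam_seq_ids:
        cache.block_tables[beam_id] = list(prompt_blocks)  # share physical block ids
        cache.seq_lens[beam_id] = prompt_len
```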
In this PR, we enable paged attention, which allows prompt sharing. The first token gets about a 3x performance improvement for beam=4 & input_len=1024, but the next token shows a performance regression because the token cache is stored in discrete blocks (`reshape_and_cache`) and the SDPA kernel cannot consume this format. The inefficient workaround is to first copy the discrete key/value token states into contiguous cached key/value states (`past_key_value.get_entire_context_states(key_states, value_states)`) and then call the SDPA op. I made a quick validation with the paged attention kernel from Intel Extension for PyTorch on my local CPU machine, and a paged-attention-format-aware SDPA kernel clearly speeds up next-token inference.
The code change looks like the following:
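(The exact change is in the diff; the sketch below only illustrates the fallback path versus a paged-attention-aware kernel, reusing the helper name mentioned above with placeholder arguments.)

```python
import torch.nn.functional as F

def attention_step(query, key_states, value_states, past_key_value, use_paged_kernel=False):
    """Simplified illustration only; masking and exact shapes are omitted for brevity."""
    if not use_paged_kernel:
        # Fallback path: gather the block-scattered cache back into contiguous K/V tensors
        # and run the stock SDPA kernel. Correct, but every decode step pays for the copy.
        key, value = past_key_value.get_entire_context_states(key_states, value_states)
        return F.scaled_dot_product_attention(query, key, value)
    # Paged path: a paged-attention-aware kernel (e.g. the one in Intel Extension for PyTorch)
    # reads K/V directly from the physical blocks via the block table, so no gather is needed.
    # Its exact signature is library-specific, so it is only indicated here.
    raise NotImplementedError("call the paged-attention-aware kernel here")
```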
The results ('sdpa' attn_implementation, imperative mode) are as follows:
- top is the dynamic cache and bottom is paged attention
- ~50 ms latency reduction for the next tokens with the paged-attention-aware SDPA kernel
- 3x speedup for the first token with prompt sharing
So, if we can merge this PR into transformers, we will cooperate with the PyTorch team to enable the related paged attention operators in PyTorch; the PyTorch team is also happy to do this work.
- `reshape_and_cache` # stores the past key/value states into the paged cache blocks.
- `paged_attention` # does the scaled dot product based on the paged-attention kv_cache format.
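Conceptually, these two operators behave roughly like the following pure-PyTorch reference (semantics only; the real kernels and their exact signatures live in the PyTorch/IPEX implementations):

```python
import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, slot_mapping):
    """Reference semantics: scatter this step's K/V tokens into their cache slots.

    key/value:             [num_tokens, num_heads, head_dim]
    key_cache/value_cache: [num_blocks, block_size, num_heads, head_dim]
    slot_mapping:          [num_tokens], flat slot id = block_id * block_size + offset
    """
    block_size = key_cache.shape[1]
    for i, slot in enumerate(slot_mapping.tolist()):
        block_id, offset = divmod(slot, block_size)
        key_cache[block_id, offset] = key[i]
        value_cache[block_id, offset] = value[i]

def paged_attention_ref(query, key_cache, value_cache, block_table, context_len):
    """Reference semantics: single-sequence decode-step attention over the paged cache.

    query: [num_heads, head_dim]; block_table: list of physical block ids for this sequence.
    """
    block_size = key_cache.shape[1]
    keys, values = [], []
    for pos in range(context_len):
        block_id, offset = block_table[pos // block_size], pos % block_size
        keys.append(key_cache[block_id, offset])      # [num_heads, head_dim]
        values.append(value_cache[block_id, offset])
    k = torch.stack(keys, dim=1)                      # [num_heads, context_len, head_dim]
    v = torch.stack(values, dim=1)
    scores = (query.unsqueeze(1) @ k.transpose(1, 2)) / (query.shape[-1] ** 0.5)
    return (scores.softmax(dim=-1) @ v).squeeze(1)    # [num_heads, head_dim]
```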
@gante just a soft reminder, could you help review again?
Hi @liangan1 👋 I'm holding the review until this PR is merged, as it might change the API of caches -- https://github.com/huggingface/transformers/pull/29180
Thanks.
@liangan1 apologies for the slow reply rate: all caches now share a standard interface, and they are all standalone objects that can be passed in and out of the model :)
We are now at a point where we can accept PRs for new cache objects 🙌 The current prototype assumes the old `StaticCache` approach, with one object per attention block, which is no longer compatible.
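For reference, from the user side it now looks roughly like this (illustrative snippet; the model name is just an example, and a new paged cache would be passed the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "meta-llama/Llama-2-7b-hf"  # example model only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Paged attention stores the KV cache in blocks", return_tensors="pt")
past_key_values = DynamicCache()  # any Cache subclass is passed in the same way
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
# The same standalone cache object comes back out and can be reused for the next step.
past_key_values = outputs.past_key_values
```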
Thanks for your info. We are also working on the PagedAttention design for the Torchao project, and it is compatible with the semantic abstraction of transformers. We will refine this PR ASAP.