
Paged Attention Based on the Latest Cache Design

Open liangan1 opened this issue 2 years ago • 12 comments

What does this PR do?

Based on the latest cache design in PR #26681, this PR implements the Paged Attention KV cache proposed by this paper.

Fixes #27303
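For illustration, a minimal sketch of the block-pool / block-table idea the paper describes; the class and attribute names here are assumptions for exposition, not this PR's actual code:

```python
# Illustrative only: KV states live in fixed-size physical blocks; a per-sequence block
# table maps logical token positions to physical blocks, so memory is allocated on demand.
from typing import Dict, List

import torch


class PagedKVPool:
    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int):
        self.block_size = block_size
        self.key_pool = torch.zeros(num_blocks, block_size, num_heads, head_dim)
        self.value_pool = torch.zeros(num_blocks, block_size, num_heads, head_dim)
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}   # seq_id -> physical block ids
        self.seq_lens: Dict[int, int] = {}             # seq_id -> tokens written so far

    def append_token(self, seq_id: int, key: torch.Tensor, value: torch.Tensor) -> None:
        """Write one token's [num_heads, head_dim] key/value into the sequence's last block."""
        pos = self.seq_lens.get(seq_id, 0)
        if pos % self.block_size == 0:                 # current block is full (or none allocated yet)
            if not self.free_blocks:
                raise RuntimeError("KV cache pool exhausted")
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop())
        block = self.block_tables[seq_id][pos // self.block_size]
        self.key_pool[block, pos % self.block_size] = key
        self.value_pool[block, pos % self.block_size] = value
        self.seq_lens[seq_id] = pos + 1
```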

Who can review?

@tomaarsen @gante @patrickvonplaten @jgong5 @jianan-gu

liangan1 avatar Nov 20 '23 07:11 liangan1

This should wait until the cache refactor is finished cc @gante: https://github.com/huggingface/transformers/pull/27407

patrickvonplaten avatar Nov 20 '23 11:11 patrickvonplaten

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Dec 20 '23 08:12 github-actions[bot]

Please help reopen this PR; we will refine it based on the latest transformers.

liangan1 avatar Jan 02 '24 00:01 liangan1

@liangan1 reopened :D

gante avatar Jan 10 '24 13:01 gante

@liangan1 reopened :D

Thanks. I will refresh this PR ASAP.

liangan1 avatar Jan 11 '24 00:01 liangan1

@gante I have rebased this PR and validated the functionality with the llama model for both greedy and beam search; please help review. The update function should be a good starting point for understanding it.
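For context, a hedged skeleton of that update entry point, following the Cache interface from PR #26681; the bookkeeping steps in the comments are paraphrased from this thread, not copied from the PR:

```python
# Skeleton only; the class name and internals are placeholders, not the PR's code.
from typing import Optional, Tuple

import torch
from transformers.cache_utils import Cache


class PagedAttentionCacheSketch(Cache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. allocate new blocks for this layer if the current ones are full
        # 2. scatter key_states/value_states into the block pools (reshape_and_cache-style copy)
        # 3. return the key/value states the attention layer should attend over
        raise NotImplementedError("illustrative skeleton only")
```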

liangan1 avatar Jan 15 '24 09:01 liangan1

Thanks @gante, I will refine it according to your comments.

liangan1 avatar Jan 19 '24 01:01 liangan1

Thank you for working on this, it seems like it's growing in the right direction 💪 Apologies for the delayed review, I went back to reread the paper.

In addition to the comments added in the code:

A - General comments:

  1. nit: whenever possible, prepend _ to the names of methods that you don't expect to be used outside the class (e.g. allocate -> _allocate)
  2. nit: raise exceptions instead of using assert (toy example below)
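A toy before/after for those two nits; the method and attribute names are made up for illustration:

```python
class _Example:
    free_blocks: list[int] = []

    # before: public name + assert
    def allocate(self, n: int) -> list[int]:
        assert n <= len(self.free_blocks), "not enough free blocks"
        return [self.free_blocks.pop() for _ in range(n)]

    # after: underscore-prefixed internal helper, explicit exception instead of assert
    def _allocate(self, n: int) -> list[int]:
        if n > len(self.free_blocks):
            raise ValueError(f"requested {n} blocks, only {len(self.free_blocks)} free")
        return [self.free_blocks.pop() for _ in range(n)]
```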

B - What would I like to see in the next review:

  1. A test to ensure the outputs with the original cache don't change when we switch to this new cache (a sketch of such a test follows this list)
  2. Some benchmarks -- the point of Paged Attention is performance, so there is no point in merging the code if it ends up being slower and/or consuming more memory 🤗
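A sketch of what such an equivalence test could look like, assuming a pre-built cache object can be injected through past_key_values the way DynamicCache is in recent transformers versions; the tiny checkpoint name is only a suggestion:

```python
import torch


def assert_cache_equivalence(model, tok, candidate_cache, prompt="paged attention test"):
    """Greedy-generate with the default cache and with `candidate_cache`; token ids must match."""
    inputs = tok(prompt, return_tensors="pt")
    ref = model.generate(**inputs, max_new_tokens=8, do_sample=False)
    # for this PR, candidate_cache would be the new paged attention cache object
    new = model.generate(**inputs, max_new_tokens=8, do_sample=False,
                         past_key_values=candidate_cache)
    torch.testing.assert_close(ref, new)


# sanity check of the harness itself with an existing cache class, e.g. on
# "hf-internal-testing/tiny-random-LlamaForCausalLM":
#   from transformers.cache_utils import DynamicCache
#   assert_cache_equivalence(model, tok, DynamicCache())
```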

Refined according to "A - General comments" and added a UT to ensure the output is unchanged with the new cache.

liangan1 avatar Feb 27 '24 11:02 liangan1

@gante Sorry for the late reply due to holidays. The static cache is a good choice for greedy search: it uses a large buffer to store the past key/value states and removes the concat overhead. For beam search, however, there is also reorder_cache overhead on top of the concat overhead in the attention module. Besides, the prompt can be shared across beams to reduce the first-token latency, especially for large batches or long sequences. Prompt sharing has the following advantages:

  • Only 1/beam_size of the memory consumption of the native implementation (illustrated below).
  • Better cache locality from prompt reuse.
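Toy numbers for the 1/beam_size claim (illustrative arithmetic, not measured data):

```python
# With prompt sharing, every beam's block table points at the same prompt blocks,
# so the prompt KV is stored once instead of once per beam.
beam_size, prompt_len, block_size = 4, 1024, 16
prompt_blocks = prompt_len // block_size          # 64 physical blocks hold the prompt KV once

native_prompt_blocks = beam_size * prompt_blocks  # native beam search: one copy per beam = 256

shared = list(range(prompt_blocks))
block_tables = {beam: list(shared) for beam in range(beam_size)}  # same physical block ids
print(prompt_blocks / native_prompt_blocks)       # 0.25 == 1 / beam_size
```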

In this PR, we enable paged attention, which allows prompt sharing. The first token gets about a 3x performance improvement for beam=4 & input_len=1024, but the next tokens see a performance regression because the token cache is stored in discrete blocks (reshape_and_cache) and the SDPA kernel does not support this format. The inefficient workaround is to first copy the discrete key/value token states into contiguous cached key/value states (past_key_value.get_entire_context_states(key_states, value_states)) and then call the SDPA op. I made a quick validation with the paged attention kernel from Intel Extension for PyTorch on my local CPU machine, and a paged-attention-format-aware SDPA kernel clearly speeds up next-token inference.

The code change looks like the following: [image]
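Since the screenshot is not reproduced here, a rough stand-in for the fallback path described above (gather the block-scattered KV into contiguous tensors, then call stock SDPA); the pool/block-table layout is an assumption, and this is not the code from the screenshot:

```python
import torch
import torch.nn.functional as F


def gather_then_sdpa(query, key_pool, value_pool, block_table, context_len):
    """Plays the role of get_entire_context_states + SDPA for one decode step.

    query:                [1, num_heads, 1, head_dim]
    key_pool/value_pool:  [num_blocks, block_size, num_heads, head_dim]
    block_table:          list of physical block ids for this sequence
    """
    blocks = torch.tensor(block_table)
    # gather/copy into contiguous [context_len, num_heads, head_dim] tensors
    keys = key_pool[blocks].reshape(-1, *key_pool.shape[2:])[:context_len]
    values = value_pool[blocks].reshape(-1, *value_pool.shape[2:])[:context_len]
    # reorder to [1, num_heads, context_len, head_dim]
    keys = keys.permute(1, 0, 2).unsqueeze(0)
    values = values.permute(1, 0, 2).unsqueeze(0)
    # the gather/copy above is exactly what a paged-attention-aware SDPA kernel avoids
    return F.scaled_dot_product_attention(query, keys, values)
```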

The results ('sdpa' attn_implementation, imperative mode) are as follows:

  • Top is the dynamic cache and bottom is paged attention.
  • ~50 ms latency reduction for the next tokens with the paged-attention-aware SDPA kernel.
  • 3x speedup for the first token with prompt sharing.

[image]

So, if we can merge this PR into transformers, we will cooperate with the PyTorch team to enable the related paged attention operators in PyTorch; the PyTorch team is also glad to do this work.

  • reshape_and_cache, which stores the past key/value states (reference sketch below).
  • paged_attention, which does the scaled dot product based on the paged attention KV cache format.
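A pure-PyTorch reference for the first of those two ops; the real kernels live in Intel Extension for PyTorch / vLLM, and this signature and layout are illustrative only:

```python
import torch


def reshape_and_cache_ref(key, value, key_pool, value_pool, slot_mapping, block_size):
    """Scatter per-token key/value states into the paged pools.

    key / value:          [num_tokens, num_heads, head_dim] new states for this step
    key_pool/value_pool:  [num_blocks, block_size, num_heads, head_dim]
    slot_mapping:         [num_tokens] flat slot index (block_id * block_size + offset) per token
    """
    block_ids = slot_mapping // block_size
    offsets = slot_mapping % block_size
    key_pool[block_ids, offsets] = key
    value_pool[block_ids, offsets] = value
```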

liangan1 avatar Feb 28 '24 05:02 liangan1

@gante just a gentle reminder: could you help review again?

liangan1 avatar Mar 11 '24 03:03 liangan1

Hi @liangan1 👋 I'm holding the review until this PR is merged, as it might change the API of caches -- https://github.com/huggingface/transformers/pull/29180

gante avatar Mar 21 '24 09:03 gante

Hi @liangan1 👋 I'm holding the review until this PR is merged, as it might change the API of caches -- #29180

Thanks.

liangan1 avatar Mar 22 '24 00:03 liangan1

@liangan1 apologies for the slow reply rate: all caches now share a standard interface, and they are all standalone objects that can be passed in and out of the model :)

We are now at a point where we can accept PRs for new cache objects 🙌 The current prototype assumes the old StaticCache approach, with one object per attention block, which is no longer compatible.
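As a usage sketch of that standalone-object API (the checkpoint name is just an example; a paged cache would be passed in and out the same way DynamicCache is here):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

cache = DynamicCache()  # a paged attention cache object would plug in at this same spot
inputs = tok("Paged attention stores KV states in blocks", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, past_key_values=cache, use_cache=True)

# the same standalone cache object comes back and can be fed to the next forward pass
cache = out.past_key_values
```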

gante avatar May 21 '24 17:05 gante

@liangan1 apologies for the slow reply rate: all caches now share a standard interface, and they are all standalone objects that can be passed in and out of the model :)

We are now at a point where we can accept PRs for new cache objects 🙌 The current prototype assumes the old StaticCache approach, with one object per attention block, which is no longer compatible.

Thanks for your info. We are also working on the PagedAttention design for the torchao project, and it is compatible with the cache semantics abstraction of transformers. We will refine this PR ASAP.

liangan1 avatar May 22 '24 01:05 liangan1