
[Misc] Add attention sinks


Overview

This PR adds experimental support for attention sinks (#1304), based on the StreamingLLM paper (https://arxiv.org/abs/2309.17453) and its repo (https://github.com/mit-han-lab/streaming-llm). Support is currently limited to RoPE and ALiBi models (e.g. Llama, Mistral/Mixtral, Falcon, Bloom, MPT). The attention sink is hard-coded as the first block of tokens in a sequence.

Usage

Set use_attention_sinks=True when instantiating LLM or LLMEngine, or pass the --use-attention-sinks CLI argument. Also set enforce_eager=True (attention sinks do not currently work with CUDA graphs), and ensure the attention backend in use is FlashAttention, XFormers, or FlashInfer (WIP).
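
A minimal usage sketch, assuming this PR's branch; the model name and sampling settings are placeholders, and `use_attention_sinks` exists only on this branch, not in mainline vLLM:

```python
from vllm import LLM, SamplingParams

# Hypothetical example on this PR's branch.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder RoPE model
    use_attention_sinks=True,           # flag added by this PR
    enforce_eager=True,                 # CUDA graphs are not supported with sinks yet
)

# max_tokens can push the sequence past the model's context length,
# since this PR bypasses the usual length-based stopping condition.
params = SamplingParams(temperature=0.8, max_tokens=8192)
print(llm.generate(["Tell me a very long story."], params)[0].outputs[0].text)
```

The CLI equivalent would be to pass --use-attention-sinks together with --enforce-eager when launching the server entrypoint.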

Background

Experiments show that the attention mechanism heavily attends to the first few tokens of the sequence, regardless of what those tokens are. Once the sequence length exceeds the model's context length and we start evicting tokens from the beginning of the KV cache (in a sliding-window fashion), the model generates garbage (high perplexity).

This is where attention sinks come in. By always preserving the KVs of the first few tokens of the sequence while using a sliding-window approach for the rest of the KV cache, the model can continue to generate sensible output (low perplexity). In theory, the model can stream indefinitely, as long as cache eviction is handled properly. Note that the sliding-window length equals the model's context length.

Example

Suppose our model's context length is 2048 tokens, i.e. 128 blocks of 16 tokens each. We pass in a prompt of 2000 tokens. For the next 48 generated tokens nothing changes; at that point all 128 blocks are filled.

Normally, vLLM forces generation to stop here because the model's context length has been reached. With attention sinks, we bypass this stopping condition and keep generating.

At the next decode, we write the 2049th token to the cache and compute the 2050th token (1-based indexing). Here we edit the block table to be [block_table[0]] + block_table[2:], effectively ignoring the 2nd block while retaining the 1st block, which is our attention sink. Notice that the modified block table still has length 128, because a 129th block was just allocated for token 2049. This modified block table is then used in the attention kernel.

Every 16th decode that follows will ignore an additional block, but always retain the 1st block as the sink.
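
As a concrete illustration of the walkthrough above, here is a minimal sketch (not the PR's actual code) of how the modified block table could be computed; `sink_block_table` is a hypothetical helper:

```python
def sink_block_table(block_table: list[int], max_model_len: int,
                     block_size: int = 16) -> list[int]:
    """Keep block_table[0] (the attention sink) and drop the oldest non-sink
    blocks so the table never exceeds max_model_len // block_size entries."""
    max_blocks = max_model_len // block_size
    num_dropped = max(0, len(block_table) - max_blocks)
    return [block_table[0]] + block_table[1 + num_dropped:]

# The walkthrough: context length 2048, block size 16, token 2049 just written,
# so a 129th physical block exists.
table = list(range(129))
modified = sink_block_table(table, max_model_len=2048)
assert len(modified) == 128 and modified[:3] == [0, 2, 3]
```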

Modifications

This PR adds a StreamingAttentionSink layer that computes attention using a modified block table: the "sink" block concatenated with the remaining sliding-window blocks. In the RoPE case, we always store pre-RoPE keys in the cache, and extra work must be done at every decode to re-rotate all keys of a sequence according to their new positions in the cache. Note: because of this extra work, attention sinks incur a significant drop in tokens/s for RoPE models (around 50-70% for Llama).
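
To make the RoPE bookkeeping concrete, here is a minimal sketch of re-rotating pre-RoPE keys with their new in-cache positions. This is not the PR's kernel code; `apply_rope` is a simplified, single-head, interleaved-pair RoPE and the shapes are illustrative only.

```python
import torch

def apply_rope(keys: torch.Tensor, positions: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of `keys` ([num_tokens, head_dim]) by angles
    that depend on `positions` ([num_tokens]). Simplified RoPE."""
    head_dim = keys.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]   # [num_tokens, head_dim // 2]
    cos, sin = angles.cos(), angles.sin()
    k1, k2 = keys[..., 0::2], keys[..., 1::2]
    out = torch.empty_like(keys)
    out[..., 0::2] = k1 * cos - k2 * sin
    out[..., 1::2] = k1 * sin + k2 * cos
    return out

# At each decode past the context length, the cached pre-RoPE keys are rotated
# with positions given by their current slot in the (sink + window) cache,
# not by their original token indices.
pre_rope_keys = torch.randn(2048, 128)                   # illustrative cache contents
new_positions = torch.arange(2048)                       # contiguous in-cache positions
rotated_keys = apply_rope(pre_rope_keys, new_positions)  # fed to the attention kernel
```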

use_attention_sinks is now an argument to LLMEngine, which passes it to the model runner, which in turn injects attention sinks into the model's modules. On every forward call of a model attention layer, the normal attention logic is replaced by the StreamingAttentionSink logic.

The scheduler evicts (frees) a block (the "ignored" block) whenever a new block is allocated past the model's context length, such that the total number of used blocks is capped at max_model_len // block_size.
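
A hedged sketch of that scheduler-side invariant (not the actual scheduler code; `free_block` stands in for whatever the block manager uses to return a block to the free pool):

```python
def allocate_with_cap(block_table: list[int], new_block: int, free_block,
                      max_model_len: int, block_size: int) -> None:
    """Append a newly allocated block; if that exceeds the cap, free the
    oldest non-sink block so used blocks stay <= max_model_len // block_size."""
    block_table.append(new_block)
    max_blocks = max_model_len // block_size
    if len(block_table) > max_blocks:
        evicted = block_table.pop(1)   # never pop index 0: that is the sink
        free_block(evicted)
```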

Future Work

  • Other attention backends: ROCMFlashAttention, torch SDPA
  • Support LoRA: LoRA requests with attention sinks are currently untested.
  • Integrate with speculative decoding: StreamingAttentionSink assumes only one token is generated per decode step.
  • Integrate with prefix caching: StreamingAttentionSink directly edits the block table for every decode (past the context length), so the hash table for prefix caching cannot be used currently.

felixzhu555 · Mar 19, 2024

Hi @felixzhu555, it is https://arxiv.org/abs/2309.17453, right?

rkooo567 · Mar 22, 2024

Yep, trying to implement the logic from that paper. Their repo is https://github.com/mit-han-lab/streaming-llm.

felixzhu555 · Mar 22, 2024

We need to bring @rlouf, the person in charge of Outlines, into this PR; it seems that your PR is failing on the guided decoding part. I'll try to bring him in to help.

jqueguiner · Mar 22, 2024

To speed up the CI queue for #5905, I've cancelled the distributed tests for the latest CI run in this PR since they won't pass anyway until #5905 has been merged. Please merge main into your branch after that happens so that the CI can pass once again.

DarkLight1337 · Jun 28, 2024

Hi @felixzhu555, is this PR still in progress, and should I expect it to be merged?

hustxiayang · Oct 18, 2024

Hi @hustxiayang, sorry, this PR likely won't get merged; it remains an experimental prototype based on an older version of vLLM. After the ongoing engine refactor is complete, the memory manager in vLLM should become more extensible and attention sinks can be supported more easily, at which point we can probably open a new PR.

felixzhu555 · Oct 20, 2024

@felixzhu555 thanks a lot for your clarification!

hustxiayang · Oct 24, 2024