
[WIP] Add FlexAttention to V1


Summary

This PR adds FlexAttention as a new unified_attention backend for the V1 engine.

This requires torch 2.7+, since a number of dynamic-shape issues that show up here by default were fixed in that release.

Design

FlexAttention is broken up into two distinct phases: block mask creation and the call to forward. For most Transformers, the N attention layers share a common attention pattern, so we can amortize the cost of block mask creation over the N attention layers. This lends itself pretty nicely to the Metadata Builder for the unified_attention op.
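
As a rough sketch of the two phases (a minimal example assuming a fixed causal pattern shared by all layers; shapes, names, and the layer loop are illustrative, not the PR's actual code):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # Phase 1: build the BlockMask once (the metadata-building step).
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=2048, KV_LEN=2048, device="cuda")

    # Phase 2: reuse the same BlockMask for each of the N attention layers.
    q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    for _ in range(4):  # stand-in for the model's N layers
        out = flex_attention(q, k, v, block_mask=block_mask)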

The majority of the work here is building the correct BlockMask.

The current BlockTable is of the following form:

This block table is a map from logical KV pages to Physical KV pages in the paged KV cache. It has a size of MAX_REQS x (MAX_SEQ_LEN//PAGE_SIZE).

FlexAttention has no notion of a page table, so we have to build an inverse mapping from physical pages (the full paged KV cache is the input to the kernel) to logical indices. We then use these logical indices to determine whether we should compute attention for a given query × KV pair. A sketch of this inversion is shown after the memory notes below.


    Logical to Physical (Original block_table):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Logical Blocks:  0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Physical Blocks: 3  5  1  7  4  2  0  6   │
    └───────────────────────────────────────────┘

    This function creates the inverse mapping:

    Physical to Logical (Inverse mapping):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Physical Blocks: 0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Logical Blocks:  6  2  5  0  4  1  7  3   │
    └───────────────────────────────────────────┘

  • Uses more memory than the page table
  • Required memory: MAX_REQS × NUM_PAGES
  • For smaller models the number of physical pages can be large:
    • Up to 178,375 pages: (total tokens) / default_page_size = 2,854,000 / 16
    • For comparison, a typical max_seq_len of 2048 corresponds to only 2048 / 16 = 128 pages per request in the original block table
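
A rough sketch of how this inversion can be built (tensor names are illustrative, not the PR's exact ones; it assumes every block_table entry is a valid physical page, as in the diagram above):

    import torch

    def invert_block_table(block_table: torch.Tensor, num_physical_pages: int) -> torch.Tensor:
        """block_table: [MAX_REQS, MAX_SEQ_LEN // PAGE_SIZE], logical -> physical.
        Returns [MAX_REQS, num_physical_pages], physical -> logical (-1 = unmapped)."""
        max_reqs, num_logical = block_table.shape
        inverse = torch.full((max_reqs, num_physical_pages), -1,
                             dtype=torch.long, device=block_table.device)
        logical_ids = torch.arange(num_logical, device=block_table.device)
        # inverse[r, block_table[r, l]] = l for every request r and logical block l
        inverse.scatter_(1, block_table.long(), logical_ids.expand(max_reqs, -1))
        return inverse

    # Request 0 from the diagram above:
    bt = torch.tensor([[3, 5, 1, 7, 4, 2, 0, 6]])
    print(invert_block_table(bt, 8))  # tensor([[6, 2, 5, 0, 4, 1, 7, 3]])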

Setting up a Generic Physical to Logical Rewriter

Once we have this Physical to Logical map, we can abstract it away from the different logical mask_mods. We do this with this function: https://github.com/vllm-project/vllm/pull/16078/files#diff-0310608cf47330020e617d94f28ce469e6e802e291f33ce4bce90e22e11cc7e5R196

By default, this type of attention can be seen as a document-packed or variable-sequence-length transformation, where all separate queries are splatted into one super sequence. We have a small lookup from physical q_idx to req_id.

We use this req_id to get the inverse page table, and with it we restrict attention to valid positions (valid blocks with kv_idx below the current seq_len).

We then adjust the q_idx by the request's offset, which is 0 during prefill.
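
A simplified sketch of what this rewriting wrapper looks like, assuming doc_ids (physical q_idx to req_id), inverse_block_table, seq_lens, and q_offsets are the lookup tensors described above (these names are illustrative, not the PR's exact ones):

    def as_physical_mask_mod(logical_mask_mod, doc_ids, inverse_block_table,
                             seq_lens, q_offsets, page_size):
        def physical_mask_mod(b, h, q_idx, kv_idx):
            req = doc_ids[q_idx]                        # physical query token -> request id
            logical_block = inverse_block_table[req, kv_idx // page_size]
            logical_kv = logical_block * page_size + kv_idx % page_size
            # Only attend within valid blocks and below the request's current seq_len.
            valid = (logical_block >= 0) & (logical_kv < seq_lens[req])
            logical_q = q_idx - q_offsets[req]          # offset is 0 during prefill
            return valid & logical_mask_mod(b, h, logical_q, logical_kv)
        return physical_mask_mod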

Once all that's done, we can have a pure "logical mask mod", which by default, and for most models, will be:

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

Adding New Variants

The nice thing about the above setup is that it makes adding new variants simpler, since we have a generic paged+packed rewriter. Users will only need to register a simple logical mask mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L177
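
For example, a hypothetical sliding-window causal variant would only need a logical mask mod along these lines (the 1024-token window is an assumed value, not something from this PR):

    def sliding_window_causal(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        in_window = (q_idx - kv_idx) <= 1024  # assumed window size
        return causal & in_window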

For score_mods I didn't put much effort in and they are mostly disabled for now, since that should be an easy follow-up. For instance, if you wanted to support tanh softcapping, we would just need to pass in a tanh_score_mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L433
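
As a sketch, such a score_mod could look like the following (the 30.0 cap is an illustrative value; the wiring into the backend is not part of this PR):

    import torch

    def tanh_softcap_score_mod(score, b, h, q_idx, kv_idx):
        softcap = 30.0  # assumed cap value for illustration
        return softcap * torch.tanh(score / softcap)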

Performance Gaps (for now)

Trace for a baseline with enforce_eager, but with the flex components compiled, for Qwen 1.5B with a full KV cache (i.e. 2.84 million tokens) on a single GPU.

Trace: https://fburl.com/1uo5h5uy TLP: https://fburl.com/z5u609g8

In this case we see 2 frames, with 1 recompile each. The recompiles come from the difference in query length between prefill and decode. This is exactly what mark_dynamic should be able to solve, but since we have a dynamic integer "Q_LEN" we need to use TORCH_COMPILE_DYNAMIC_SOURCES. I couldn't get this to work, so I'm punting for now and accepting the two recompiles.

In this case it takes around 3.7 seconds before we settle into a steady state (for the given requests). [screenshot omitted]

Let's zoom in on the steady state: [screenshot omitted]

drisspg · Apr 04 '25