
[WIP][Attention] FlashAttn MLA

Open LucasWilkinson opened this pull request 9 months ago • 1 comment

Use the latest FlashAttention code to compute decode MQA in MLA.

Based on https://github.com/vllm-project/vllm/pull/13111; merge that first.
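For readers unfamiliar with the decode-as-MQA trick, here is a minimal, illustrative sketch (not the PR's kernel code; the shapes are assumed, roughly DeepSeek-V2-like): after weight absorption, every query head attends to the same compressed latent KV cache, so decode attention becomes multi-query attention with a single shared KV head.

```python
# Illustrative sketch only: MLA decode expressed as MQA. Shapes/dims below
# are assumptions (roughly DeepSeek-V2-like), not values taken from this PR.
import torch

batch, num_heads, kv_len = 2, 16, 128
latent_dim, rope_dim = 512, 64          # compressed latent + decoupled RoPE dims
qk_dim = latent_dim + rope_dim

# One query token per request (query-length-1 decode); with absorbed weights
# the per-head query lives in the latent (+ RoPE) space.
q = torch.randn(batch, num_heads, 1, qk_dim)
# A single shared KV head: the compressed latent cache plus its RoPE part.
kv_cache = torch.randn(batch, 1, kv_len, qk_dim)

scale = qk_dim ** -0.5  # illustrative; the real scale uses the un-absorbed head dim
scores = torch.matmul(q, kv_cache.transpose(-1, -2)) * scale   # [B, H, 1, kv_len]
probs = torch.softmax(scores, dim=-1)
# Values are the latent part of the same cache entries (the RoPE part is key-only),
# so this is exactly MQA: many query heads, one shared KV head.
out = torch.matmul(probs, kv_cache[..., :latent_dim])          # [B, H, 1, latent_dim]
# A real implementation then up-projects `out` back to the model hidden size.
```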

LucasWilkinson (Mar 05 '25)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] (Mar 05 '25)

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] (Mar 26 '25)

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] (Apr 22 '25)

UPDATED INFO

NOTE: This PR requires vllm-project/flash-attention PR #84. Merge that and update GIT_TAG in vllm_flash_attn.cmake before merging!

^EDIT: Done.

Test - correctness

```
pytest tests/v1/attention/test_mla_backends.py

VLLM_ATTENTION_BACKEND=<backend> lm_eval --model vllm \
  --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": "auto"}' \
  --tasks gsm8k --batch_size auto
```
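As a quick sanity check outside lm_eval, the backend can also be selected from Python through the same environment variable; a hedged sketch (the model and prompt are just examples, not part of the test plan above):

```python
# Sketch: pick the attention backend via VLLM_ATTENTION_BACKEND before vLLM
# is imported, then run a short greedy generation as a smoke test.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"  # or "FLASHMLA"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```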

FlashMLA baseline (<backend> = FLASHMLA):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6702|±  |0.0129|
|     |       |strict-match    |     5|exact_match|↑  |0.6619|±  |0.0130|

FlashAttention MLA (<backend> = FLASH_ATTN_MLA):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6687|±  |0.0130|
|     |       |strict-match    |     5|exact_match|↑  |0.6573|±  |0.0131|

Test - throughput (end-to-end, query length 1 decodes)

(Running on 4x H100)

```
VLLM_ATTENTION_BACKEND=<backend> vllm bench throughput \
  --model=RedHatAI/DeepSeek-Coder-V2-Instruct-FP8 \
  --dataset-name=random --input-len=8192 --output-len=1024 \
  --num-prompts=1000 --kv-cache-dtype=auto \
  --tensor-parallel-size 4 --enable-expert-parallel \
  --max_model_len=16384
```

| Backend        | Requests/s | Total tokens/s | Output tokens/s |
|----------------|-----------:|---------------:|----------------:|
| FLASH_ATTN_MLA |       0.40 |        3715.94 |          412.92 |
| FLASHMLA       |       0.37 |        3381.36 |          375.73 |

On this workload, FLASH_ATTN_MLA delivers roughly 10% higher token throughput than FLASHMLA.

MatthewBonanni (Aug 26 '25)

The decode threshold has now been tuned by sweeping the two pipelines over query lengths and batch sizes; results are below:

[Figure: mla_multibatch_analysis]

Two policies were examined: one that raises the decode threshold to a constant 512 tokens, and one that computes the threshold as a quadratic function of batch size. The constant threshold of 512 is adopted for simplicity; a sketch of both policies follows the figure below.

[Figure: mla_policy_speedup_comparison]
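For reference, a minimal sketch of how the two candidate policies could be expressed (the quadratic coefficients are placeholders rather than the tuned values, and the routing helper is hypothetical):

```python
# Sketch of the two decode-threshold policies compared above. Requests whose
# query length is at or below the threshold are routed to the decode (MQA)
# pipeline; longer queries go to the prefill pipeline.

def decode_threshold_constant() -> int:
    """Policy adopted in this PR: a fixed threshold of 512 tokens."""
    return 512

def decode_threshold_quadratic(batch_size: int,
                               a: float = 0.0, b: float = 0.0,
                               c: float = 512.0) -> int:
    """Alternative policy: threshold as a quadratic function of batch size.

    The coefficients here are placeholders, not the values from the sweep.
    """
    return int(a * batch_size ** 2 + b * batch_size + c)

def use_decode_pipeline(query_len: int, batch_size: int,
                        policy: str = "constant") -> bool:
    """Hypothetical routing helper showing how a policy would be applied."""
    if policy == "constant":
        threshold = decode_threshold_constant()
    else:
        threshold = decode_threshold_quadratic(batch_size)
    return query_len <= threshold
```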

MatthewBonanni (Aug 27 '25)

@LucasWilkinson Thanks for the comments! I've addressed them.

MatthewBonanni (Aug 27 '25)

LGTM

robertgshaw2-redhat (Aug 28 '25)