[WIP][Attention] FlashAttn MLA
Use the latest FlashAttention code to compute the decode MQA in MLA.
Based on https://github.com/vllm-project/vllm/pull/13111; merge that first.
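For context, here is a minimal reference sketch (plain PyTorch with illustrative shapes, not the FlashAttention kernel or vLLM's actual code) of why MLA decode reduces to MQA: after weight absorption, every query head attends to the same per-token compressed latent entry (e.g. a 512-dim latent plus 64 rope dims in DeepSeek-V2/V3), so a single shared KV head can serve all query heads.

```python
# Hedged reference sketch, not vLLM's implementation: after weight absorption,
# MLA decode is MQA where all query heads share one 576-dim "key"
# (512 latent + 64 rope) and the value is the 512-dim latent part.
import torch

def mla_decode_as_mqa(q, latent_cache, softmax_scale):
    """
    q:            [batch, num_heads, 576]  absorbed queries (latent + rope dims)
    latent_cache: [batch, seq_len, 576]    per-token compressed KV cache
    returns:      [batch, num_heads, 512]  attention output in latent space
    """
    # One shared KV head: keys are the full 576-dim latent entries,
    # values are the first 512 dims of the same entries.
    k = latent_cache                      # [batch, seq_len, 576]
    v = latent_cache[..., :512]           # [batch, seq_len, 512]

    scores = torch.einsum("bhd,bsd->bhs", q, k) * softmax_scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhs,bsd->bhd", probs, v)
```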
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
UPDATED INFO
NOTE: This PR requires vllm-project/flash-attention PR #84. Merge that and update GIT_TAG in vllm_flash_attn.cmake before merging!
^EDIT: Done.
Test - correctness
pytest tests/v1/attention/test_mla_backends.py
VLLM_ATTENTION_BACKEND=<backend> lm_eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": "auto"}' --tasks gsm8k --batch_size auto
FlashMLA baseline (<backend> = FLASHMLA):
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.6702|± |0.0129|
| | |strict-match | 5|exact_match|↑ |0.6619|± |0.0130|
FlashAttention MLA (<backend> = FLASH_ATTN_MLA):
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.6687|± |0.0130|
| | |strict-match | 5|exact_match|↑ |0.6573|± |0.0131|
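For a quick offline spot-check of the same comparison, a minimal sketch (the backend names come from this PR and the existing FlashMLA backend; the model, prompt, and sampling settings are just illustrative):

```python
# Illustrative sketch: selecting the attention backend for an offline check.
# VLLM_ATTENTION_BACKEND must be set before vLLM picks its backend.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"  # or "FLASHMLA"

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    trust_remote_code=True,
)
outputs = llm.generate(
    ["Janet has 3 apples and buys 5 more. How many apples does she have?"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```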
Test - throughput (end-to-end, query-length-1 decodes)
(Running on 4x H100)
VLLM_ATTENTION_BACKEND=<backend> vllm bench throughput --model=RedHatAI/DeepSeek-Coder-V2-Instruct-FP8 --dataset-name=random --input-len=8192 --output-len=1024 --num-prompts=1000 --kv-cache-dtype=auto --tensor-parallel-size 4 --enable-expert-parallel --max_model_len=16384
FLASH_ATTN_MLA: Throughput: 0.40 requests/s, 3715.94 total tokens/s, 412.92 output tokens/s
FLASHMLA: Throughput: 0.37 requests/s, 3381.36 total tokens/s, 375.73 output tokens/s
The decode threshold has now been tuned by sweeping the two pipelines over query lengths and batch sizes; results below.
Two policies were examined: one that raises the decode threshold to a constant 512, and one that computes the threshold as a quadratic function of batch size. The constant threshold of 512 is adopted for simplicity.
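For reference, the adopted constant-threshold policy amounts to something like the following (names and the routing helper are hypothetical, not vLLM's actual code):

```python
# Hypothetical sketch of the adopted policy: requests whose query length is at
# or below a fixed threshold are routed to the decode (MQA) pipeline; longer
# queries fall back to the prefill path.
FLASH_ATTN_MLA_DECODE_THRESHOLD = 512  # tokens per query, chosen from the sweep

def use_decode_pipeline(query_len: int) -> bool:
    """Route short queries (typical decode steps) to the MQA decode kernel."""
    return query_len <= FLASH_ATTN_MLA_DECODE_THRESHOLD
```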
@LucasWilkinson Thanks for the comments! I've addressed them.
LGTM