
[Attention] MLA with chunked prefill

Open · LucasWilkinson opened this pull request 10 months ago • 8 comments

Need to do more benchmarking to see whether this should be on by default in V0, but it lays the groundwork for a V1 implementation. (https://github.com/vllm-project/vllm/pull/13111 may help performance.)

lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=False --task gsm8k --num_fewshot=5 --limit 100

vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=False), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.66|±  |0.0476|
|     |       |strict-match    |     5|exact_match|↑  | 0.66|±  |0.0476|


lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=True --task gsm8k --num_fewshot=5 --limit 100


vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=True), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.66|±  |0.0476|
|     |       |strict-match    |     5|exact_match|↑  | 0.66|±  |0.0476|
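
For reference, the same configuration can be exercised offline through vLLM's Python API. This is only a minimal sketch: the prompt is a placeholder, and the keyword arguments simply mirror the `--model_args` above, which the `LLM` constructor forwards to the engine.

```python
from vllm import LLM, SamplingParams

# Mirrors the lm_eval --model_args used above; enable_chunked_prefill is the
# flag this PR exercises for MLA models.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    tensor_parallel_size=2,
    dtype="auto",
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
    max_model_len=16384,
    enable_chunked_prefill=True,  # set to False to reproduce the first run
)

# Placeholder GSM8K-style prompt; lm_eval drives the real evaluation.
outputs = llm.generate(
    ["Question: Natalia sold clips to 48 of her friends in April ... Answer:"],
    SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```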

Shout-out to @pathorn for assisting with hardening this PR

Future work:

  • [x] Allocate the worst case result of self.kv_b_proj(kv_c_normed) in the profile run
  • [ ] https://github.com/vllm-project/vllm/pull/12639#discussion_r1956673731
  • [ ] Improved algorithm for allocating workspace among batch elements (see the sketch below)
  • [ ] Improve how the workspace is allocated
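
For the last two workspace items, a rough sketch of one possible direction (hypothetical helper `split_workspace` and names; not this PR's actual implementation): carve a single pre-allocated buffer into per-request views so nothing is allocated on the hot path.

```python
import torch

def split_workspace(workspace: torch.Tensor,
                    prefill_token_counts: list[int]) -> list[torch.Tensor]:
    """Hypothetical helper: hand out slices of one pre-allocated workspace to
    the prefill requests in a batch, largest first, so no per-step allocation
    happens during attention."""
    order = sorted(range(len(prefill_token_counts)),
                   key=lambda i: prefill_token_counts[i], reverse=True)
    views: list[torch.Tensor] = [torch.empty(0)] * len(prefill_token_counts)
    offset = 0
    for i in order:
        n = prefill_token_counts[i]
        if offset + n > workspace.shape[0]:
            # Not enough room: a real implementation would chunk this request
            # further or defer it to a later step.
            raise RuntimeError("workspace too small for this batch")
        views[i] = workspace[offset:offset + n]
        offset += n
    return views

# Example: views = split_workspace(torch.empty(8192, 576), [4096, 2048, 1024])
```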

LucasWilkinson · Feb 01 '25

👋 Hi! Thank you for contributing to the vLLM project. Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of that by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of the following:

  • Add the ready label to the PR
  • Enable auto-merge

🚀

github-actions[bot] · Feb 01 '25

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Feb 06 '25

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Feb 07 '25

NOTE: @pathorn found a bug with FP8 in R1; will notify here when it is resolved.

Edit: this has been resolved.

LucasWilkinson · Feb 14 '25

Removed the V1 tag because, although this PR does move some code out of the V1 flash attention backend, I didn't want anyone to get the impression that it adds MLA support in V1.

tlrmchlsmth · Feb 14 '25

@tlrmchlsmth

> I wonder if it would be better to detect if we are in the profile run and allocate temporary tensors of size equal to the upper limit on the workspace required, instead of what we are doing now. It sounds like there might be an edge case where we run out of memory, and if so we should address it before landing.

This should be addressed by: https://github.com/vllm-project/vllm/pull/12639/commits/1c595972ef844b97538cd8f93ad38c3904287cf0

without this commit I get:

model weights take 84.11GiB; non_torch_memory takes 5.13GiB; PyTorch activation peak memory takes 0.19GiB; the rest of the memory reserved for KV Cache is 36.41GiB.

with it I get:

model weights take 84.11GiB; non_torch_memory takes 5.13GiB; PyTorch activation peak memory takes 1.17GiB; the rest of the memory reserved for KV Cache is 35.42GiB.
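
A rough illustrative sketch of that idea, i.e. sizing the temporary kv_b_proj output for the worst case during the profiling run (hypothetical function and argument names; the real change is the commit linked above):

```python
import torch

def allocate_kv_b_proj_workspace(kv_c_normed: torch.Tensor,
                                 proj_out_dim: int,
                                 max_num_batched_tokens: int,
                                 is_profile_run: bool) -> torch.Tensor:
    """Hypothetical sketch: during the memory-profiling run, allocate for the
    worst case so the profiled activation peak covers the real workspace and
    the KV-cache budget is sized correctly; otherwise allocate only what the
    current step needs."""
    num_tokens = max_num_batched_tokens if is_profile_run else kv_c_normed.shape[0]
    return torch.empty(num_tokens, proj_out_dim,
                       dtype=kv_c_normed.dtype, device=kv_c_normed.device)
```

That is consistent with the numbers above: charging the worst-case workspace to the profile run raises the reported activation peak from 0.19 GiB to 1.17 GiB and shrinks the KV-cache budget by roughly the same amount.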

LucasWilkinson · Feb 15 '25

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Feb 15 '25

NOTE: @pathorn found a bug when stress-testing R1; will notify here when it is resolved.

https://vllm-dev.slack.com/archives/C08AD2B5HH8/p1739521144253459?thread_ts=1739486497.566799&cid=C08AD2B5HH8

Edit: should be resolved by https://github.com/vllm-project/vllm/commit/920ecc69ab68849eeed14204a8a6fa88179b684c#diff-00753a3c1f378f8b8c60e9eb10b94c3cbbfcea74fca6e66712e5d4ae360f6741

LucasWilkinson · Feb 15 '25

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Feb 21 '25

Hi @LucasWilkinson, thanks for your wonderful work! I am a little confused about the backend returned from get_attn_backend_cls. Since we need to set VLLM_USE_V1 to use chunked prefill, from here wouldn't we get vllm.v1.attention.backends.flash_attn.FlashAttentionBackend instead of vllm.attention.backends.triton_mla.TritonMLABackend?

ZhongYingMatrix · Feb 23 '25