[NVIDIA] Add Cutlass MLA backend

Open kaixih opened this issue 8 months ago • 19 comments

This PR introduces the CUTLASS_MLA_VLLM_V1 backend, enabling support for ops.cutlass_mla_decode() on NVIDIA Blackwell GPUs.

It also includes performance results using DeepSeek-V3 on 8×B200 GPUs under DP+EP parallelism settings, which delivers ~17% improved throughput.

# With default triton backend:
============ Serving Benchmark Result ============
Successful requests:                     2989
Benchmark duration (s):                  1046.01
Total input tokens:                      2989000
Total generated tokens:                  2989000
Request throughput (req/s):              2.86
Output token throughput (tok/s):         2857.52
Total Token throughput (tok/s):          5715.04
---------------Time to First Token----------------
Mean TTFT (ms):                          200716.51
Median TTFT (ms):                        199463.35
P99 TTFT (ms):                           395239.25
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          826.04
Median TPOT (ms):                        826.20
P99 TPOT (ms):                           1001.39
---------------Inter-token Latency----------------
Mean ITL (ms):                           826.04
Median ITL (ms):                         648.89
P99 ITL (ms):                            8337.69
==================================================

# With cutlass_mla backend:
============ Serving Benchmark Result ============
Successful requests:                     2989
Benchmark duration (s):                  881.52
Total input tokens:                      2989000
Total generated tokens:                  2989000
Request throughput (req/s):              3.39
Output token throughput (tok/s):         3390.73
Total Token throughput (tok/s):          6781.46
---------------Time to First Token----------------
Mean TTFT (ms):                          190244.11
Median TTFT (ms):                        189563.96
P99 TTFT (ms):                           372713.07
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          685.60
Median TPOT (ms):                        686.96
P99 TPOT (ms):                           858.01
---------------Inter-token Latency----------------
Mean ITL (ms):                           685.60
Median ITL (ms):                         518.56
P99 ITL (ms):                            7738.23
==================================================

To repro the results:

# Server side with the triton backend (set VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1 for the cutlass backend; the full cutlass command is spelled out below):
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
  vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --max-model-len=2048 \
    --block-size=128 \
    --max-num-seqs=512 \
    --gpu_memory_utilization=0.97 \
    --data-parallel-size $NUM_GPUS --enable-expert-parallel \
    --disable-log-requests

# client side:
python $VLLM_PATH/benchmarks/benchmark_serving.py \
  --model deepseek-ai/DeepSeek-V3 \
  --dataset-name random \
  --ignore-eos \
  --num-prompts 3000 \
  --max-concurrency 3000 \
  --random-input-len 1000 \
  --random-output-len 1000
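
For the cutlass results above, the only change is the backend environment variable called out in the server-side comment, spelled out in full here (NUM_GPUS=8 matches the 8×B200 setup described at the top):

# Server side with cutlass backend:
export NUM_GPUS=8   # 8×B200, DP+EP as described above
VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
  vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --max-model-len=2048 \
    --block-size=128 \
    --max-num-seqs=512 \
    --gpu_memory_utilization=0.97 \
    --data-parallel-size $NUM_GPUS --enable-expert-parallel \
    --disable-log-requests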

cc. @kushanam

kaixih avatar May 04 '25 07:05 kaixih

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar May 04 '25 07:05 github-actions[bot]

The perf is looking really good! Thanks for the contribution!

Do you mind doing accuracy checks?

VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1 VLLM_USE_V1=1 lm-eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --tasks gsm8k --num_fewshot 5 --batch_size auto

Do you know of any command to test a model with num_heads = 128? And preferably with no TP.

kaixih avatar May 06 '25 16:05 kaixih

Do you know of any command to test a model with num_heads = 128? And preferably with no TP.

Not that I'm aware of :/ this is the smallest MLA model I know of

LucasWilkinson avatar May 06 '25 22:05 LucasWilkinson

This is the smallest model with MLA (@tlrmchlsmth found it the other day): https://huggingface.co/deepseek-ai/deepseek-vl2-tiny

mgoin avatar May 06 '25 23:05 mgoin

This is the smallest model with MLA (@tlrmchlsmth found it the other day): https://huggingface.co/deepseek-ai/deepseek-vl2-tiny

Just want to clarify that the way I found it was: grep -r deepseek tests | grep tiny

tlrmchlsmth avatar May 07 '25 01:05 tlrmchlsmth

Ah I don't think it's an MLA model :/

    "kv_lora_rank": null,
    ...
    "use_mla": false,

LucasWilkinson avatar May 07 '25 02:05 LucasWilkinson

Edit: oh, and ideally I'd still like to see accuracy numbers...

@LucasWilkinson this DeepSeek-V2-Lite-Chat only has attention head number == 16, and --tp=2 is not OK. Any advice on a model with head num == 128?

kaixih avatar May 08 '25 21:05 kaixih

Other than that, I do think we should turn it on by default for Blackwell. Any reason not to?

My main concern is that the CUTLASS MLA kernel has more limited support compared to the Triton version. For example, num_heads must be 128, and the page table must be padded to a multiple of 128 / page_size. We're working on expanding the supported use cases, but until then, we'd prefer to keep this as an optional backend. What do you think?
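
To make the page-table constraint concrete, here is a small illustration with hypothetical numbers (not code from this PR):

# Page-table entries must be padded to a multiple of 128 / page_size.
page_size=64                      # hypothetical KV-cache block size
multiple=$((128 / page_size))     # -> 2
num_pages=7                       # pages a request actually needs
padded=$(( (num_pages + multiple - 1) / multiple * multiple ))  # -> 8
echo "pad page table from ${num_pages} to ${padded} entries"

With --block-size=128, as used in the repro commands above, 128 / page_size is 1 and no extra padding is needed.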

kaixih avatar May 08 '25 22:05 kaixih

Other than that, I do think we should turn it on by default for Blackwell. Any reason not to?

My main concern is that the CUTLASS MLA kernel has more limited support compared to the Triton version. For example, num_heads must be 128, and the page table must be padded to a multiple of 128 / page_size. We're working on expanding the supported use cases, but until then, we'd prefer to keep this as an optional backend. What do you think?

Oh interesting, can we turn it on by default for models with num_heads == 128? As for block size, I think we can force it similarly to FlashMLA (https://github.com/vllm-project/vllm/blob/5e6f93948449b8095e8eef5c3d99a8726e216a44/vllm/platforms/cuda.py#L145-L158); we plan to eventually handle this more gracefully (i.e. better support for the attention backend specifying the block size it wants if the user has not specified one).

I'm OK with landing this for now and opening a new PR to turn it on by default, but I think we still should; we should strive to make the defaults the best they can be.

LucasWilkinson avatar May 09 '25 03:05 LucasWilkinson

Edit: oh, and ideally I'd still like to see accuracy numbers...

@LucasWilkinson this DeepSeek-V2-Lite-Chat only has attention head number == 16, and --tp=2 is not OK. Any advice on a model with head num == 128?

Ah the only one I know of is DeepSeek V3 or R1, if you have a system that can run that then those accuracy numbers would be good to see!

LucasWilkinson avatar May 09 '25 03:05 LucasWilkinson

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @kaixih.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
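
(For reference, the usual way to do this on a fork, assuming the vLLM repo is configured as the upstream remote:)

git fetch upstream
git rebase upstream/main
git push --force-with-lease origin HEAD   # update the PR branch on your fork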

mergify[bot] avatar May 10 '25 23:05 mergify[bot]

Got the lm_eval output with the cutlass backend, which matches the triton backend:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.96|±  |0.0197|
|     |       |strict-match    |     5|exact_match|↑  | 0.96|±  |0.0197|

How to repro:

# server side: bash run.sh [cutlass]
if [[ "$1" == "cutlass" ]]; then
  export VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1
fi

VLLM_LOGGING_LEVEL=DEBUG \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
  vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --max-model-len=4096 \
    --block-size=128 \
    --max-num-seqs=512 \
    --gpu_memory_utilization=0.97 \
    --data-parallel-size $NUM_GPUS --enable-expert-parallel \
    --disable-log-requests &

# client side
lm_eval --model local-completions --tasks gsm8k \
    --model_args model=deepseek-ai/DeepSeek-V3,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=5,max_retries=3,tokenized_requests=False \
    --limit 100
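
Given the run.sh wrapper above, both backends are exercised the same way; only the positional argument differs:

# triton baseline
bash run.sh
# cutlass backend
bash run.sh cutlass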

kaixih avatar May 11 '25 04:05 kaixih

The updated throughput after rebasing:

triton:
============ Serving Benchmark Result ============
Successful requests:                     3000
Benchmark duration (s):                  847.99
Total input tokens:                      2997000
Total generated tokens:                  3000000
Request throughput (req/s):              3.54
Output token throughput (tok/s):         3537.76
Total Token throughput (tok/s):          7071.98
---------------Time to First Token----------------
Mean TTFT (ms):                          155845.04
Median TTFT (ms):                        157803.49
P99 TTFT (ms):                           304408.28
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          686.05
Median TPOT (ms):                        685.18
P99 TPOT (ms):                           824.89
---------------Inter-token Latency----------------
Mean ITL (ms):                           686.05
Median ITL (ms):                         562.49
P99 ITL (ms):                            6502.67
==================================================

cutlass:
============ Serving Benchmark Result ============
Successful requests:                     3000
Benchmark duration (s):                  669.38
Total input tokens:                      2997000
Total generated tokens:                  3000000
Request throughput (req/s):              4.48
Output token throughput (tok/s):         4481.75
Total Token throughput (tok/s):          8959.01
---------------Time to First Token----------------
Mean TTFT (ms):                          154187.64
Median TTFT (ms):                        156351.06
P99 TTFT (ms):                           302269.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          509.17
Median TPOT (ms):                        507.58
P99 TPOT (ms):                           649.29
---------------Inter-token Latency----------------
Mean ITL (ms):                           509.17
Median ITL (ms):                         370.06
P99 ITL (ms):                            6451.99
==================================================

kaixih avatar May 11 '25 07:05 kaixih

we plan to eventually handle this more gracefully (i.e. better support for the attention backend specifying the block size it wants if the user has not specified one). I'm OK with landing this for now and opening a new PR to turn it on by default, but I think we still should; we should strive to make the defaults the best they can be.

I agree — let's land this PR first, and then improve the attention backend picker in a more graceful way. I briefly looked into the code, and supporting this properly would require significant changes — for example, the function would need access to model_config (for num_heads), cache_config (for block_size), and parallel_config (for TP/DP). I think that level of change is better handled in a separate PR.

kaixih avatar May 11 '25 07:05 kaixih

This PR is pretty red... these tests aren't all failing on main, are they?

tlrmchlsmth avatar May 16 '25 02:05 tlrmchlsmth

This PR is pretty red... these tests aren't all failing on main, are they?

I am checking the logs, but nothing seems to be related to my change. Can you please advise? @tlrmchlsmth

kaixih avatar May 16 '25 05:05 kaixih

@tlrmchlsmth OK, I think I have done all I can. Basically, I have reverted all my changes to existing files (leaving only the one that adds a new file), and you can see the tests are still red. Can you help?

kaixih avatar May 16 '25 23:05 kaixih

Force-pushed the original PR.

kaixih avatar May 16 '25 23:05 kaixih

@tlrmchlsmth Can you please advise? ^^

kaixih avatar May 20 '25 03:05 kaixih

I think we have fixed the CI, could you rebase the PR again?

houseroad avatar May 31 '25 08:05 houseroad

@houseroad Thanks, I've just rebased. However, the Lint and Deploy Charts / lint-and-deploy (pull_request) check is failing, and it appears to be unrelated to the changes in this PR. Can you advise?

kaixih avatar Jun 02 '25 03:06 kaixih

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @kaixih.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Jun 03 '25 16:06 mergify[bot]