[Core] Support full cuda graph in v1

Open chanh opened this issue 8 months ago • 13 comments

Summary

Support capturing a single CUDA graph for the entire model's forward pass, instead of piecewise graphs. This requires creating persistent buffers to make attention graphable. Credit to @tlrmchlsmth for the original implementation.

Limitations:

  1. This only works with V1 + FA3, since FA2 currently is not graphable due to an optimization for GQA.
  2. This doesn't work with Cascade Attention.

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

This reduces median TPOT by 7% for small models like Qwen 2.5 1.5B.

Before

With piecewise graphs, there are multiple kernel launches per layer, with more gaps between kernel executions (13 ms to decode one token in profiling mode): [Screenshot 2025-04-04 at 12:04:24 PM]

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.15    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.95     
Total Token throughput (tok/s):          1066.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          29.08     
Median TTFT (ms):                        28.89     
P99 TTFT (ms):                           36.17     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.75      
Median TPOT (ms):                        5.75      
P99 TPOT (ms):                           6.00      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.75      
Median ITL (ms):                         5.70      
P99 ITL (ms):                            6.58      
==================================================

After

There is now a single kernel launch, with almost no gaps between kernel executions (6 ms to decode one token in profiling mode): [Screenshot 2025-04-04 at 12:05:54 PM]

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.10    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.99     
Total Token throughput (tok/s):          1066.92   
---------------Time to First Token----------------
Mean TTFT (ms):                          29.52     
Median TTFT (ms):                        30.47     
P99 TTFT (ms):                           39.97     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.31      
Median TPOT (ms):                        5.33      
P99 TPOT (ms):                           5.56      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.31      
Median ITL (ms):                         5.27      
P99 ITL (ms):                            6.18      
==================================================

** Above benchmarks performed with:

VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct  --enable-prefix-caching --dtype float16 --disable-log-requests -O3 (or -O4)

vllm bench serve \
        --model Qwen/Qwen2.5-1.5B-Instruct \
        --request-rate 1 \
        --num-prompts 100 \
        --random-input-len 1000 \
        --random-output-len 100 \
        --tokenizer Qwen/Qwen2.5-1.5B-Instruct \
        --ignore-eos
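
As a sanity check on the 7% figure, the relative change can be recomputed from the median values in the two tables above. This is only arithmetic on the reported numbers; the helper name `relative_reduction` is made up for this sketch.

```python
def relative_reduction(before: float, after: float) -> float:
    """Fractional reduction going from `before` to `after` (positive = faster)."""
    return (before - after) / before


# Median values copied from the Before / After tables above, in milliseconds.
median_tpot = {"before": 5.75, "after": 5.33}
median_ttft = {"before": 28.89, "after": 30.47}

tpot_gain = relative_reduction(median_tpot["before"], median_tpot["after"])
ttft_gain = relative_reduction(median_ttft["before"], median_ttft["after"])

print(f"median TPOT reduction: {tpot_gain:.1%}")
print(f"median TTFT reduction: {ttft_gain:.1%}")
```

The TPOT reduction comes out at roughly 7.3%. Median TTFT does not improve in this run, so the gain is confined to the decode steps.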

chanh avatar Apr 04 '25 20:04 chanh

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Apr 04 '25 20:04 github-actions[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @chanh.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 08 '25 06:04 mergify[bot]

@chanh thanks for the PR, I have tested llama 8b on my side with your PR and I see ~7% improvement for TPOT. Great work!

Before PR:

============ Serving Benchmark Result ============
Successful requests:                     50        
Benchmark duration (s):                  45.05     
Total input tokens:                      25600     
Total generated tokens:                  12800     
Request throughput (req/s):              1.11      
Output token throughput (tok/s):         284.11    
Total Token throughput (tok/s):          852.34    
---------------Time to First Token----------------
Mean TTFT (ms):                          22.43     
Median TTFT (ms):                        22.10     
P99 TTFT (ms):                           27.48     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.63      
Median TPOT (ms):                        7.63      
P99 TPOT (ms):                           7.77      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.63      
Median ITL (ms):                         7.61      
P99 ITL (ms):                            8.45      
==================================================

After PR:

============ Serving Benchmark Result ============
Successful requests:                     50        
Benchmark duration (s):                  44.93     
Total input tokens:                      25600     
Total generated tokens:                  12800     
Request throughput (req/s):              1.11      
Output token throughput (tok/s):         284.88    
Total Token throughput (tok/s):          854.64    
---------------Time to First Token----------------
Mean TTFT (ms):                          22.72     
Median TTFT (ms):                        22.93     
P99 TTFT (ms):                           27.49     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.14      
Median TPOT (ms):                        7.14      
P99 TPOT (ms):                           7.28      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.14      
Median ITL (ms):                         7.12      
P99 ITL (ms):                            8.05      
==================================================

alexm-redhat avatar Apr 08 '25 19:04 alexm-redhat

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

just a heads up @zou3519

sarckk avatar Apr 09 '25 03:04 sarckk

Thanks to @alexm-redhat for verifying!

chanh avatar Apr 09 '25 19:04 chanh

@chanh tell me if you need help with extending the tests, I can do it on my side.

alexm-redhat avatar Apr 11 '25 15:04 alexm-redhat

Thanks for the PR! I will review it this weekend (maybe Tyler and Rob, too).

WoosukKwon avatar Apr 11 '25 15:04 WoosukKwon

I ran some latency-focused testing on this PR using LLaMA 3.2 1B Instruct with a small batch size (~1-2) in a highly latency-constrained setting where minimizing CUDA graph launches can significantly improve GPU utilization. Here are the results:

Before PR:

Average latency: 56.82 ms
p50 latency: 53.00 ms
p90 latency: 64.00 ms
p95 latency: 68.00 ms
p99 latency: 82.23 ms

After PR:

Average latency: 50.30 ms
p50 latency: 48.00 ms
p90 latency: 58.00 ms
p95 latency: 61.00 ms
p99 latency: 67.00 ms

This shows a notable improvement across the board, particularly in tail latencies. Great work!

dblincoe avatar Apr 11 '25 16:04 dblincoe

@WoosukKwon would be good to have your quick feedback

alexm-redhat avatar Apr 17 '25 21:04 alexm-redhat

@youkaichao

if we can figure out the conditions, we can try to enable it automatically I think, without introducing a new user interface like level 4 optimization.

To my understanding, there are two reasons why this can't be enabled by default:

  1. Cascade attention is not supported when the full graph is used
  2. The attention kernel's performance could be lower, because the kernel's scheduling is determined at the capture time and is not adaptive to different configurations.

WoosukKwon avatar Apr 21 '25 06:04 WoosukKwon

there are two reasons why this can't be enabled default

then we should have clear documentation around it, when full-cudagraph can be used.

The attention kernel's performance could be lower, because the kernel's scheduling is determined at the capture time and is not adaptive to different configurations.

I'm worried that this will lead to a bug, so we need to confirm whether the FA3 kernel with prefill can be used with cudagraph. We only capture a cudagraph per num-batched-tokens value, but those tokens can be a mix of prefill and decode, and the kv-cache (context length) dimension is not taken into consideration when capturing the cudagraph.
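
The concern can be illustrated with a small sketch. It assumes graphs are looked up purely by token count, padded up to the nearest captured size; the size list, the `graph_key` helper, and the batch dictionaries are hypothetical and are not vLLM's actual data structures.

```python
from bisect import bisect_left

# Hypothetical capture sizes; the real list depends on the engine configuration.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]


def graph_key(num_batched_tokens: int):
    """Smallest captured size that fits the batch, or None if nothing fits."""
    i = bisect_left(CAPTURE_SIZES, num_batched_tokens)
    return CAPTURE_SIZES[i] if i < len(CAPTURE_SIZES) else None


# Two batches with the same token count but very different shapes.
decode_only = {"query_lens": [1, 1, 1, 1], "context_lens": [4096, 12, 900, 7]}
mixed = {"query_lens": [3, 1], "context_lens": [0, 20000]}

for batch in (decode_only, mixed):
    num_tokens = sum(batch["query_lens"])
    # The key depends only on num_tokens; context lengths play no part in it.
    print(num_tokens, graph_key(num_tokens), max(batch["context_lens"]))
```

Both batches map to the same key even though one is decode-only and the other mixes a prefill with a long-context decode, so whatever the attention kernel fixed at capture time has to be valid for both.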

youkaichao avatar Apr 21 '25 12:04 youkaichao

Actually, what's max_seq_len and max_query_len at the capture time? If it's 0, I guess this could cause a bug.

max_query_len is the num_tokens being passed to _dummy_run which is the cudagraph batch size to capture.

As for max_seq_len it shouldn't matter what it is at capture time because at replay time it's dynamically computed based on the persistent buffer seq_lens.

At least this is my understanding, could be wrong.
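
A toy, CPU-only simulation of that capture/replay contract is sketched below. It does not use the CUDA graph API; `ToyGraph` and `max_seq_len_kernel` are hypothetical stand-ins, and the only point is that a replay reads whatever the persistent buffer holds at replay time rather than what it held at capture time.

```python
class ToyGraph:
    """Stand-in for a captured graph: records calls once, then replays them."""

    def __init__(self):
        self._ops = []

    def capture(self, fn, *buffers):
        # Record the callable and the identity of its buffers, not their values.
        self._ops.append((fn, buffers))

    def replay(self):
        for fn, buffers in self._ops:
            fn(*buffers)


MAX_BATCH = 4
seq_lens = [0] * MAX_BATCH  # persistent input buffer, all zeros at capture time
out = [0]                   # persistent output buffer


def max_seq_len_kernel(seq_lens_buf, out_buf):
    # Computed from the buffer contents each time the graph is replayed.
    out_buf[0] = max(seq_lens_buf)


graph = ToyGraph()
graph.capture(max_seq_len_kernel, seq_lens, out)

# Later step: update the persistent buffer in place, then replay.
seq_lens[:2] = [17, 4096]
graph.replay()
print(out[0])  # 4096

# A freshly allocated list is invisible to the captured graph.
fresh_seq_lens = [9999, 0, 0, 0]
graph.replay()
print(out[0])  # still 4096
```

The second replay shows why the buffers have to be persistent: the graph only knows about the objects it was captured with, so inputs must be copied into them rather than allocated anew.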

chanh avatar Apr 25 '25 09:04 chanh

I think cudagraph and torch.compile are not directly connected.

In my opinion it is better to introduce a standalone option to enable full cudagraph capture.

Also, -O4 might be better reserved for a future, more aggressive torch.compile mode.

vadiklyutiy avatar Apr 25 '25 18:04 vadiklyutiy

Hmm.... For some reason, I see lower performance for Llama 3.2 1B with the full cuda graphs, compared to piecewise cuda graphs.

WoosukKwon avatar May 05 '25 16:05 WoosukKwon

@WoosukKwon what command were you using to run llama 3.2 1B?

alexm-redhat avatar May 05 '25 18:05 alexm-redhat

@alexm-redhat It's

python benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B --batch-size 1 --input-len 4096 --output-len 50 --no-enable-prefix-caching --compilation-config '{"full_cuda_graph": true}'

I think it makes sense because the full graph capture essentially disables the FlashAttention's heuristics to tune the kernel parameters based on max_seq_len and other information.

WoosukKwon avatar May 05 '25 18:05 WoosukKwon

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @chanh.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar May 06 '25 14:05 mergify[bot]

@chanh we are getting close!! can you please fix the merge conflict? thank you

simon-mo avatar May 07 '25 00:05 simon-mo

@chanh Thanks for pushing this through!

tlrmchlsmth avatar May 07 '25 13:05 tlrmchlsmth

I think we may need to disable ahead-of-time scheduling for FA3 when using full cuda-graph:

https://github.com/vllm-project/vllm/blob/1a6af1453d2077832c3d5e8bcd60a5ef6a95e46b/vllm/v1/attention/backends/flash_attn.py#L341-L354

since this scheduler may choose a different number of splits than what the graph was captured with

do we have lm-eval accuracy results with full cuda-graphs on?

LucasWilkinson avatar May 07 '25 15:05 LucasWilkinson

Will discuss with you over Slack

chanh avatar May 07 '25 18:05 chanh

Okay disabled it for now.

chanh avatar May 07 '25 20:05 chanh

lm-eval results

[Current branch, Full CUDA Graph flag enabled, modified lm-eval to pass the compilation_config JSON properly to vLLM]
VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 \
lm_eval --model vllm \
  --model_args "pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,compilation_config={\"full_cuda_graph\": true}" \
  --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5982|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5898|±  |0.0135|


[Main branch]
VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 \
lm_eval --model vllm \
  --model_args "pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1" \
  --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5951|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5891|±  |0.0136|
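
To put the two tables side by side, the difference in each score can be compared with the reported standard errors. The sketch below treats the two runs as independent, which is a simplifying assumption since both use the same gsm8k questions; `z_score` is a helper defined only for this example.

```python
import math


def z_score(a: float, se_a: float, b: float, se_b: float) -> float:
    """Difference of two scores in units of their combined standard error."""
    return (a - b) / math.sqrt(se_a**2 + se_b**2)


# (value, stderr) pairs copied from the two tables above:
# full CUDA graph branch first, main branch second.
results = {
    "flexible-extract": ((0.5982, 0.0135), (0.5951, 0.0135)),
    "strict-match": ((0.5898, 0.0135), (0.5891, 0.0136)),
}

for name, ((full, se_full), (main, se_main)) in results.items():
    z = z_score(full, se_full, main, se_main)
    print(f"{name}: diff={full - main:+.4f}, z={z:.2f}")
```

Both differences are a small fraction of one combined standard error, so on this task the two branches are not distinguishable.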

chanh avatar May 07 '25 23:05 chanh

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

just a heads up @zou3519

What is special about local attention?

renjie0 avatar May 13 '25 04:05 renjie0

@chanh It seems that full cuda graph support outputs garbage on latest main. Do you have any idea?

Juelianqvq avatar May 28 '25 06:05 Juelianqvq

@chanh +1 - it seems like the test was never added to CI (needs to be added manually to .buildkite/test-pipeline.yml). When I run the test locally, the first shape works and all the other shapes output garbage.

ProExpertProg avatar May 29 '25 15:05 ProExpertProg

  1. This only works with V1 + FA3, since FA2 currently is not graphable due to an optimization for GQA.

Hello, I have changed the code to enable full CUDA graph capture with FA2, and the results show that FA2 also works correctly. So I'm curious what exactly "FA2 currently is not graphable due to an optimization for GQA" refers to.

Lmywl avatar Jun 19 '25 12:06 Lmywl

@WoosukKwon This may be helpful. Regarding FA2, FlashInfer recently merged a PR (https://github.com/flashinfer-ai/flashinfer/pull/1137) that implements a persistent-style FA2 template. That PR unifies prefill and decode, which supports a single CUDA graph for all batch sizes and sequence lengths.

happierpig avatar Jun 19 '25 20:06 happierpig

@chanh Do you have any idea about this problem? https://github.com/vllm-project/vllm/issues/23739

xsank avatar Aug 28 '25 02:08 xsank