
[Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1

Open maleksan85 opened this issue 9 months ago • 25 comments

Speed up prefix prefill with vLLM V1 on AMD GPUs

Improvements:

  1. Vectorization in the context loop (the most complex one, as the k cache shape is very specific; see the layout sketch after this list)
  2. Refactoring of the online softmax computation
  3. Refactoring of the kernel so the autotuner can select the best config per shape (a declaration sketch follows the tuning link below)
  4. Adding a new spectrum of unrolling/staging options to the autotuner
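
To make item 1 concrete, here is a small NumPy sketch (not the PR's Triton code) of the packed key-cache layout the context loop has to read from. The shape [num_blocks, num_kv_heads, head_size // x, block_size, x] is the layout documented for vLLM's paged key cache; all dimension values below are toy numbers. Consecutive head-dim elements land on two non-adjacent axes, which is what makes vectorized loads in the context loop awkward.

import numpy as np

# Toy dimensions; x is the innermost packing factor (16 bytes / element size).
num_blocks, num_kv_heads, head_size, block_size, x = 4, 2, 16, 8, 4

# Assumed paged key-cache layout: [num_blocks, num_kv_heads, head_size//x, block_size, x]
k_cache = np.arange(num_blocks * num_kv_heads * head_size * block_size,
                    dtype=np.float32).reshape(
    num_blocks, num_kv_heads, head_size // x, block_size, x)

# Rebuilding a dense [block_size, head_size] key tile for one physical block
# and one head requires interleaving the head_size//x and x axes:
blk, head = 1, 0
k_tile = k_cache[blk, head].transpose(1, 0, 2).reshape(block_size, head_size)
print(k_tile.shape)  # (8, 16)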

More details on Triton kernel tuning: https://rocm.docs.amd.com/en/docs-6.1.1/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
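
To illustrate items 3 and 4, here is a minimal sketch of how such an autotuned Triton kernel can be declared. The config values and the NUM_UNROLL_CACHE parameter are hypothetical stand-ins for what the PR actually ships; triton.autotune, triton.Config, num_warps, and num_stages are standard Triton API.

import triton
import triton.language as tl

# Hypothetical config space; the PR's real one is larger and includes
# AMD-specific unrolling/staging choices (item 4 above).
configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn, "NUM_UNROLL_CACHE": u},
                  num_warps=w, num_stages=2)
    for bm in (16, 32)
    for bn in (32, 64)
    for u in (1, 2)
    for w in (4, 8)
]

# The autotuner re-benchmarks whenever a value named in `key` changes,
# then caches the winning config for that shape.
@triton.autotune(configs=configs, key=["MAX_SEQ_LEN_POW2"])
@triton.jit
def _fwd_kernel(q_ptr, k_cache_ptr, v_cache_ptr, out_ptr,
                MAX_SEQ_LEN_POW2,  # e.g. triton.next_power_of_2(max_input_len)
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                NUM_UNROLL_CACHE: tl.constexpr):
    ...  # attention body elided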

See the latest comments for the most recent results.

maleksan85 avatar Feb 14 '25 19:02 maleksan85

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which exercises a small and essential subset of CI tests to quickly catch errors. You can run the remaining CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Feb 14 '25 19:02 github-actions[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Feb 14 '25 19:02 mergify[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Feb 24 '25 22:02 mergify[bot]

Just made an initial pass. Not a final review. Could you describe the performance improvements that you made? Also, unless the improvements you made regress the performance of the existing kernel, I'd like to avoid creating a whole new kernel.

In the meantime, I'll try to repro your results locally and post back.

Added. Please see the updated description, as well as possible future improvements.

maleksan85 avatar Feb 27 '25 17:02 maleksan85

To validate the correctness of this kernel, can you run VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

SageMoore avatar Feb 27 '25 17:02 SageMoore

To validate the correctness of this kernel, can you run VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

The new kernel is slightly off from the original because of the online softmax computation (the divisor is accumulated over all iterations and then applied once at the end of the kernel; see the sketch after the tables below).

2025-03-03:21:07:16,982 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.622  ± 0.0217
                strict-match      5       exact_match  0.614  ± 0.0218

original:

2025-03-03:21:13:18,136 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.792  ± 0.0182
                strict-match      5       exact_match  0.778  ± 0.0186
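
For reference, here is a minimal NumPy sketch (hypothetical, not the kernel code) of the deferred-division online softmax described above: the running max m and normalizer l are updated block by block, and the accumulator is divided by l only once at the end.

import numpy as np

def online_softmax_weighted_sum(score_blocks, value_blocks):
    # Flash-attention style online softmax over a sequence of blocks,
    # deferring the division by the normalizer to the very end.
    m = -np.inf   # running max of all scores seen so far
    l = 0.0       # running sum of exp(score - m)
    acc = 0.0     # running un-normalized weighted sum of values
    for s, v in zip(score_blocks, value_blocks):
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)   # rescale previous state to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l                  # divisor applied once, at the end

# Sanity check against the naive softmax:
rng = np.random.default_rng(0)
scores, values = rng.normal(size=8), rng.normal(size=8)
w = np.exp(scores - scores.max())
assert np.allclose((w / w.sum()) @ values,
                   online_softmax_weighted_sum(np.split(scores, 4),
                                               np.split(values, 4)))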

maleksan85 avatar Mar 04 '25 06:03 maleksan85

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 04 '25 17:03 mergify[bot]

@SageMoore

Restored the softmax computation as it was in the original kernel:

2025-03-04:17:26:25,991 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.80   ± 0.0179
                strict-match      5       exact_match  0.78   ± 0.0185
--input-len 1 --output-len 32 --batch-size 16 

Opt softmax: Avg latency: 0.2485975159254546 seconds (from 0.402870576 seconds)
Orig softmax: Avg latency: 0.2614329825155437
H100 0.252115006 seconds

 --input-len 1 --output-len 32 --batch-size 64

Opt softmax: Avg latency: 0.36089399384800347 seconds (from 0.538930605 seconds)
Orig softmax: Avg latency: 0.3783661962952465 seconds
H100 0.299436951 seconds

--input-len 1 --output-len 32 --batch-size 128 

Opt softmax: Avg latency: 0.5479010133945849 seconds (from 0.817037284 seconds)
Orig softmax: Avg latency: 0.5533470625523478 seconds
H100 0.380702962 seconds

--input-len 512 --output-len 1 --batch-size 1

Opt softmax: Avg latency: 0.009134410376039645 seconds (from 0.013344106 seconds)
Orig softmax: Avg latency: 0.008717456289256612 seconds
H100 0.008099335 seconds

--input-len 2048 --output-len 1 --batch-size 1

Opt softmax: Avg latency: 0.013733387808315456 seconds (from 0.023251219 seconds)
Orig softmax: Avg latency: 0.013126811214412252 seconds
H100 0.008553044 seconds

maleksan85 avatar Mar 04 '25 17:03 maleksan85

With the mix of this kernel and mine, I'm getting the following on a second run, when the Triton cache is already populated (the cache is generated on the first run).

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  20.06
Total input tokens:                      215196
Total generated tokens:                  197921
Request throughput (req/s):              49.84
Output token throughput (tok/s):         9864.60
Total Token throughput (tok/s):          20590.21
---------------Time to First Token----------------
Mean TTFT (ms):                          1097.03
Median TTFT (ms):                        1106.22
P99 TTFT (ms):                           1445.13
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.86
Median TPOT (ms):                        40.61
P99 TPOT (ms):                           86.60
---------------Inter-token Latency----------------
Mean ITL (ms):                           33.51
Median ITL (ms):                         33.40
P99 ITL (ms):                            84.48
==================================================

and

2025-03-10:23:46:31,843 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.792  ± 0.0182
                strict-match      5       exact_match  0.778  ± 0.0186

cc @SageMoore and @comaniac, please help to land it.

maleksan85 avatar Mar 10 '25 23:03 maleksan85

@maleksan85 Which GPU are the above results on btw? If an H100, the throughput is really close to what I've been measuring with FlashAttention...

tdoublep avatar Mar 11 '25 19:03 tdoublep

@maleksan85 Which GPU are the above results on btw? If an H100, the throughput is really close to what I've been measuring with FlashAttention...

MI300

maleksan85 avatar Mar 11 '25 21:03 maleksan85

@maleksan85 Are you results above using any specific Triton commit, or do they use Triton 3.2.0 as distributed via PyPI?

tdoublep avatar Mar 14 '25 16:03 tdoublep

@maleksan85 Are your results above using any specific Triton commit, or do they use Triton 3.2.0 as distributed via PyPI?

I'm using the "default" one that ships with the ROCm Dockerfile: Version: 3.2.0+gite5be006a

maleksan85 avatar Mar 14 '25 17:03 maleksan85

I've been testing these changes on H100 and here is what I find:

Deploy server from this branch:

export TRITON_PRINT_AUTOTUNING=1
VLLM_USE_V1=1 vllm serve /models/llama3.1-8b/instruct/ \
    --disable-log-requests 

Benchmark command:

VLLM_USE_V1=1 python benchmarks/benchmark_serving.py \
    --model /models/llama3.1-8b/instruct/ \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --ignore-eos

Auto-tuning result:

Triton autotuning for function _fwd_kernel finished after 7.88s; best config selected: BLOCK_M: 32, BLOCK_N: 64, num_unroll_cache: 2, num_unroll_request: 1, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;

First run (this branch):

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  23.29     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              42.94     
Output token throughput (tok/s):         8516.11   
Total Token throughput (tok/s):          17755.83  
---------------Time to First Token----------------
Mean TTFT (ms):                          3961.21   
Median TTFT (ms):                        3756.92   
P99 TTFT (ms):                           7469.11   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          98.72     
Median TPOT (ms):                        53.60     
P99 TPOT (ms):                           262.90    
---------------Inter-token Latency----------------
Mean ITL (ms):                           41.85     
Median ITL (ms):                         24.40     
P99 ITL (ms):                            267.79    
==================================================

Second run (this branch):

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.34     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              44.75     
Output token throughput (tok/s):         8876.78   
Total Token throughput (tok/s):          18507.80  
---------------Time to First Token----------------
Mean TTFT (ms):                          3374.36   
Median TTFT (ms):                        3274.86   
P99 TTFT (ms):                           6480.41   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          104.11    
Median TPOT (ms):                        51.26     
P99 TPOT (ms):                           365.52    
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.91     
Median ITL (ms):                         25.30     
P99 ITL (ms):                            333.10    
==================================================

I believe the performance improvement between the first and second runs is due to the increase in prefix cache hit rate (from 0.9% to 41.5%) rather than any auto-tuning effects (which are already hidden by the initial test prompt).

If I compare to the results from main, it looks like the changes from this PR actually make performance quite a bit worse on H100, unfortunately.

First run (main)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  22.39     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              44.66     
Output token throughput (tok/s):         8858.87   
Total Token throughput (tok/s):          18470.46  
---------------Time to First Token----------------
Mean TTFT (ms):                          3533.85   
Median TTFT (ms):                        3402.28   
P99 TTFT (ms):                           6724.64   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          90.01     
Median TPOT (ms):                        50.50     
P99 TPOT (ms):                           233.03    
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.77     
Median ITL (ms):                         25.19     
P99 ITL (ms):                            236.98    
==================================================

Second run (main)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  20.98     
Total input tokens:                      215196    
Total generated tokens:                  198343    
Request throughput (req/s):              47.67     
Output token throughput (tok/s):         9454.58   
Total Token throughput (tok/s):          19712.51  
---------------Time to First Token----------------
Mean TTFT (ms):                          2519.86   
Median TTFT (ms):                        2201.87   
P99 TTFT (ms):                           5135.88   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          92.62     
Median TPOT (ms):                        48.39     
P99 TPOT (ms):                           311.25    
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.49     
Median ITL (ms):                         24.44     
P99 ITL (ms):                            275.93    
==================================================

@maleksan85 can you please confirm whether the numbers you are reporting on MI300x are from the first run or the second run (where the APC effect is significantly larger)?

tdoublep avatar Mar 19 '25 19:03 tdoublep

Thank you @tdoublep, will follow up on that!

maleksan85 avatar Mar 19 '25 21:03 maleksan85

(I asked @russellb to disable auto-merge until we get to the bottom of the performance numbers here)

ProExpertProg avatar Mar 20 '25 21:03 ProExpertProg

I revved torch up to 2.6 on ROCm (MI300X), which has Triton 3.2, and got the following results with

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct

and

 VLLM_USE_V1=1 python benchmarks/benchmark_serving.py --model  meta-llama/Llama-3.1-8B-Instruct --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos

Baseline

============ Serving Benchmark Result ============
Successful requests:                     990
Benchmark duration (s):                  30.60
Total input tokens:                      212354
Total generated tokens:                  196200
Request throughput (req/s):              32.36
Output token throughput (tok/s):         6412.39
Total Token throughput (tok/s):          13352.73
---------------Time to First Token----------------
Mean TTFT (ms):                          5818.41
Median TTFT (ms):                        5947.09
P99 TTFT (ms):                           11813.52
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          66.83
Median TPOT (ms):                        65.67
P99 TPOT (ms):                           114.42
---------------Inter-token Latency----------------
Mean ITL (ms):                           47.03
Median ITL (ms):                         37.90
P99 ITL (ms):                            120.12
==================================================

This PR

============ Serving Benchmark Result ============
Successful requests:                     990
Benchmark duration (s):                  30.41
Total input tokens:                      212354
Total generated tokens:                  196200
Request throughput (req/s):              32.56
Output token throughput (tok/s):         6452.11
Total Token throughput (tok/s):          13435.45
---------------Time to First Token----------------
Mean TTFT (ms):                          6106.38
Median TTFT (ms):                        5780.37
P99 TTFT (ms):                           11939.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          64.95
Median TPOT (ms):                        62.68
P99 TPOT (ms):                           102.92
---------------Inter-token Latency----------------
Mean ITL (ms):                           45.97
Median ITL (ms):                         38.01
P99 ITL (ms):                            104.37
==================================================

SageMoore avatar Mar 20 '25 21:03 SageMoore

@SageMoore @tdoublep thank you for the benchmarks. Yes, the perf of this PR is worse than the original for small contexts like 64. Investigating why.

maleksan85 avatar Mar 24 '25 17:03 maleksan85

Thank you to all reviewers for looking at this PR. Per my investigation yesterday using rocprof, on average this kernel (with slightly improved perf; will update the version shortly) gives a 0-20% boost. However, the downside with vllm serve is that every time it sees a new triton.next_power_of_2(max_input_len), Triton starts autotuning to select the best config. That is why the test has to run a few times before the perf difference shows up. On top of that, the kernel in this PR is called only for prefills, which account for 30% or less of execution time, so the overall e2e perf with the new kernel is only slightly better, as shown in the tests above.
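
To make the re-tuning trigger concrete, here is a tiny illustration (toy values): each distinct power-of-2 bucket of max_input_len is a new autotuning key, so the first batch that lands in a bucket pays the tuning cost.

import triton

# Each new bucket value triggers a fresh autotuning pass for the kernel.
for max_input_len in (500, 900, 1500, 5000):
    print(max_input_len, "->", triton.next_power_of_2(max_input_len))
# 500 -> 512, 900 -> 1024, 1500 -> 2048, 5000 -> 8192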

Examples from an rpd trace (_my_fwd_kernel is the kernel from this PR and _fwd_kernel is the original, running one after another): [three rpd trace screenshots omitted]

And the most important thing: the kernel shows a 2-3x perf boost over the original only when we have long chunked prefills, where long means a context of 256 tokens or more. So in benchmarks I see this:

my_prefix_prefill.py::test_contexted_kv_attention[block size: 32, batch size: 32, max ctx len: 4096, seq len: 1024, head size: 128, kv heads: 8, query heads: 32]
PR kernel time: 6.27 ms, Original kernel time: 19.27 ms

it is almost 3x on my benchmark script.

Similar to the above but with max seq len 4096: PR kernel time: 31.81 ms, Original kernel time: 81.84 ms.

I can share the benchmark script if anyone is interested in playing with it; a minimal sketch of such a harness follows below.
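
In the meantime, here is a minimal sketch of that kind of timing harness, assuming hypothetical wrappers pr_kernel and orig_kernel that launch the two Triton kernels on identical inputs (this is not the actual script).

import torch

def time_kernel(fn, *args, warmup=10, iters=100):
    # Time a GPU kernel with CUDA events (works on ROCm builds via HIP).
    for _ in range(warmup):   # warmup covers JIT compilation and autotuning
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # mean milliseconds per call

# Hypothetical usage with the two kernels on the same q/k/v tensors:
# print("PR kernel:      ", time_kernel(pr_kernel, q, k_cache, v_cache, out), "ms")
# print("Original kernel:", time_kernel(orig_kernel, q, k_cache, v_cache, out), "ms")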

PS: in the test run by python vllm/benchmarks/benchmark_serving.py --model /data/models/Llama-3.1-8B-Instruct --dataset-name sharegpt --dataset-path datasets/ShareGPT/ShareGPT_V3_unfiltered_cleaned_split.json there is usually (for some reason) only one prefill that has anything in context (aka cache). So the tuned kernel gets no work in the context (first) loop, which is where it outperforms the original.

maleksan85 avatar Mar 26 '25 19:03 maleksan85

However, the downside with vllm serve is that every time it sees a new triton.next_power_of_2(max_input_len), Triton starts autotuning to select the best config.

Hi @maleksan85, yes, that is indeed the biggest disadvantage of the autotuner in Triton upstream. That's why we developed triton-dejavu, a plug-in replacement for the autotuner that saves and restores autotuner runs: https://github.com/IBM/triton-dejavu

(I planned to take a look at your proposed changes and combine them with this dejavu mechanism earlier this month, but I haven't had the time yet.)

We use this, for example, for the non-chunked prefill kernel in our Triton-only plugin: https://github.com/foundation-model-stack/vllm-triton-backend/blob/d11132b6558c3f4be725bbfb918afde31f72f532/ibm-triton-lib/ibm_triton_lib/kernels/triton_flash_attention.py#L735

bringlein avatar Mar 27 '25 09:03 bringlein

for isl in 1000 5000 10000; do \
    echo $isl; \
    rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --num_prompts 1000 --no-enable-prefix-caching --max-model-len 32768 --max-num-batched-tokens 32768 --input-len $isl --output-len 100 2>&1 | grep Throughput; \
    rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --num_prompts 1000 --no-enable-prefix-caching --max-model-len 32768 --max-num-batched-tokens 32768 --input-len $isl --output-len 100 2>&1 | grep Throughput; \
    rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --num_prompts 1000 --no-enable-prefix-caching --max-model-len 32768 --max-num-batched-tokens 32768 --input-len $isl --output-len 100 2>&1 | grep Throughput; \
done

with PR (ISL/OSL):

1000/100 (1.02x)
Throughput: 7.67 requests/s, 8796.44 total tokens/s, 766.56 output tokens/s
Throughput: 7.67 requests/s, 8794.59 total tokens/s, 767.07 output tokens/s
Throughput: 7.67 requests/s, 8793.23 total tokens/s, 766.70 output tokens/s

5000/100 (1.15x)
Throughput: 1.54 requests/s, 8197.98 total tokens/s, 153.83 output tokens/s
Throughput: 1.55 requests/s, 8240.60 total tokens/s, 154.53 output tokens/s
Throughput: 1.54 requests/s, 8236.00 total tokens/s, 154.47 output tokens/s

10000/100 (1.19x)
Throughput: 0.75 requests/s, 7894.16 total tokens/s, 74.76 output tokens/s
Throughput: 0.74 requests/s, 7816.95 total tokens/s, 74.00 output tokens/s

Upstream (ISL/OSL):

1000/100
Throughput: 7.48 requests/s, 8588.32 total tokens/s, 748.50 output tokens/s
Throughput: 7.46 requests/s, 8561.20 total tokens/s, 746.17 output tokens/s
Throughput: 7.52 requests/s, 8627.33 total tokens/s, 751.56 output tokens/s

5000/100
Throughput: 1.33 requests/s, 7078.32 total tokens/s, 132.78 output tokens/s
Throughput: 1.31 requests/s, 6977.49 total tokens/s, 130.97 output tokens/s
Throughput: 1.34 requests/s, 7147.87 total tokens/s, 134.04 output tokens/s

10000/100
Throughput: 0.63 requests/s, 6599.19 total tokens/s, 62.50 output tokens/s
Throughput: 0.64 requests/s, 6733.11 total tokens/s, 63.80 output tokens/s
Throughput: 0.62 requests/s, 6561.00 total tokens/s, 62.13 output tokens/s

maleksan85 avatar Apr 08 '25 14:04 maleksan85

for isl in 1000 5000 10000; do 
   echo $isl; 
   rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_latency.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --input-len $isl --output-len 1 --batch-size 64 2>&1 | grep "Avg latency";     
   rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_latency.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --input-len $isl --output-len 1 --batch-size 64 2>&1 | grep "Avg latency";     
   rm -rf ~/.triton/cache; rm -rf /root/.cache/vllm; HIP_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python benchmarks/benchmark_latency.py -tp 4 --model /data/models/Llama-3.1-70B-Instruct --input-len $isl --output-len 1 --batch-size 64 2>&1 | grep "Avg latency"; 
done

with PR:

1000 (1.26x)
Avg latency: 0.15809809761121868 seconds
Avg latency: 0.15592556800693275 seconds
Avg latency: 0.15606141282866398 seconds

5000 (1.47x)
Avg latency: 0.35732791749760506 seconds
Avg latency: 0.3611876289981107 seconds
Avg latency: 0.36364591543873154 seconds

10000 (1.47x)
Avg latency: 0.701720359083265 seconds
Avg latency: 0.7041152139815191 seconds
Avg latency: 0.7023595116411646 seconds

Upstream:

1000
Avg latency: 0.1936436322517693 seconds
Avg latency: 0.19345509118090073 seconds
Avg latency: 0.19940307652577757 seconds

5000
Avg latency: 0.5265994310689469 seconds
Avg latency: 0.543060151146104 seconds
Avg latency: 0.5372202017654976 seconds

10000
Avg latency: 1.0412443388563892 seconds
Avg latency: 1.0152872786857188 seconds
Avg latency: 1.0336426993831993 seconds

maleksan85 avatar Apr 08 '25 16:04 maleksan85

HIP_VISIBLE_DEVICES=6 VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=/data/models/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

2025-04-08:18:10:02,846 INFO [lm_eval.loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.808  ± 0.0176
                strict-match      5       exact_match  0.782  ± 0.0185

maleksan85 avatar Apr 08 '25 18:04 maleksan85

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --model /data/models/Llama-3.1-70B-Instruct \
    --dataset-name random \
    --random-input-len 10000 \
    --random-output-len 100 \
    --num-prompts 300 \
    --seed 42 \
    --ignore-eos \
    --percentile-metrics "ttft,tpot,itl,e2el"

PR (~20% gain)

============ Serving Benchmark Result ============
Successful requests:                     300
Benchmark duration (s):                  409.78
Total input tokens:                      3000000
Total generated tokens:                  30000
Request throughput (req/s):              0.73
Output token throughput (tok/s):         73.21
Total Token throughput (tok/s):          7394.28
---------------Time to First Token----------------
Mean TTFT (ms):                          205042.73
Median TTFT (ms):                        203406.19
P99 TTFT (ms):                           400609.81
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1610.15
Median TPOT (ms):                        2027.83
P99 TPOT (ms):                           2239.19
---------------Inter-token Latency----------------
Mean ITL (ms):                           1610.15
Median ITL (ms):                         80.56
P99 ITL (ms):                            5252.32
----------------End-to-end Latency----------------
Mean E2EL (ms):                          364447.21
Median E2EL (ms):                        404161.34
P99 E2EL (ms):                           409588.24
==================================================

Upstream

============ Serving Benchmark Result ============
Successful requests:                     300
Benchmark duration (s):                  498.15
Total input tokens:                      3000000
Total generated tokens:                  30000
Request throughput (req/s):              0.60
Output token throughput (tok/s):         60.22
Total Token throughput (tok/s):          6082.51
---------------Time to First Token----------------
Mean TTFT (ms):                          249095.71
Median TTFT (ms):                        248711.87
P99 TTFT (ms):                           488484.85
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1957.47
Median TPOT (ms):                        2462.50
P99 TPOT (ms):                           2732.60
---------------Inter-token Latency----------------
Mean ITL (ms):                           1957.47
Median ITL (ms):                         80.32
P99 ITL (ms):                            8005.81
----------------End-to-end Latency----------------
Mean E2EL (ms):                          442885.68
Median E2EL (ms):                        492500.58
P99 E2EL (ms):                           497952.19
==================================================

maleksan85 avatar Apr 08 '25 22:04 maleksan85

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --model /data/models/Llama-3.1-70B-Instruct \
    --dataset-name random \
    --random-input-len 5000 \
    --random-output-len 100 \
    --num-prompts 500 \
    --seed 42 \
    --ignore-eos \
    --percentile-metrics "ttft,tpot,itl,e2el"

PR (~10% gain)

============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  319.37
Total input tokens:                      2500000
Total generated tokens:                  50000
Request throughput (req/s):              1.57
Output token throughput (tok/s):         156.56
Total Token throughput (tok/s):          7984.50
---------------Time to First Token----------------
Mean TTFT (ms):                          155485.39
Median TTFT (ms):                        149836.40
P99 TTFT (ms):                           310684.27
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1219.18
Median TPOT (ms):                        1556.81
P99 TPOT (ms):                           1629.28
---------------Inter-token Latency----------------
Mean ITL (ms):                           1219.18
Median ITL (ms):                         77.67
P99 ITL (ms):                            4265.61
----------------End-to-end Latency----------------
Mean E2EL (ms):                          276184.44
Median E2EL (ms):                        310784.82
P99 E2EL (ms):                           319205.24
==================================================

Upstream

============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  355.99
Total input tokens:                      2500000
Total generated tokens:                  50000
Request throughput (req/s):              1.40
Output token throughput (tok/s):         140.45
Total Token throughput (tok/s):          7163.04
---------------Time to First Token----------------
Mean TTFT (ms):                          172121.19
Median TTFT (ms):                        162339.60
P99 TTFT (ms):                           349045.74
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1369.76
Median TPOT (ms):                        1699.35
P99 TPOT (ms):                           1892.04
---------------Inter-token Latency----------------
Mean ITL (ms):                           1369.76
Median ITL (ms):                         78.00
P99 ITL (ms):                            6167.44
----------------End-to-end Latency----------------
Mean E2EL (ms):                          307727.51
Median E2EL (ms):                        349138.54
P99 E2EL (ms):                           355831.83
==================================================

maleksan85 avatar Apr 08 '25 22:04 maleksan85