
feat: support flashinfer kernel autotune

elvischenv opened this pull request 2 months ago · 16 comments

Motivation

FlashInfer MoE requires autotuning to select the most performant kernel. This should be done before CUDA graph capture.

Modifications

This PR adds a kernel warmup stage before CUDA graph capture. During warmup, it runs a dummy forward pass of the model under the FlashInfer autotune context.
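
As a rough sketch of this flow (not the PR's exact code): it assumes FlashInfer exposes an autotune() context manager in flashinfer.autotuner (the autotuner.py log lines quoted later in this thread come from that module) and borrows the _dummy_run name from the traceback further down; details may differ from the real implementation.

import torch
from flashinfer.autotuner import autotune

def run_flashinfer_autotune(model_runner, max_batch_size: int) -> None:
    # One dummy forward pass at the max batch size; any FlashInfer kernels
    # invoked during this pass (e.g. the MoE kernels) get autotuned before
    # CUDA graphs are captured.
    with torch.inference_mode(), autotune():  # assumed context-manager API
        model_runner._dummy_run(batch_size=max_batch_size)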

Accuracy Tests

lm_eval --model local-completions --tasks gsm8k --model_args model=openai/gpt-oss-120b,base_url=http://127.0.0.1:18000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8394|±  |0.0143|
|     |       |strict-match    |     5|exact_match|↑  |0.6318|±  |0.0188|

main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8394|±  |0.0143|
|     |       |strict-match    |     5|exact_match|↑  |0.6318|±  |0.0188|

Benchmarking and Profiling

python3 -m sglang.bench_serving --model openai/gpt-oss-120b --host 127.0.0.1 --port 18000 --backend sglang-oai --dataset-name random --random-range-ratio 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 512 --num-prompts 2560

PR (16% perf improvement):

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf
Max request concurrency:                 512
Successful requests:                     2560
Benchmark duration (s):                  138.33
Total input tokens:                      2621440
Total input text tokens:                 2621440
Total input vision tokens:               0
Total generated tokens:                  2621440
Total generated tokens (retokenized):    2539762
Request throughput (req/s):              18.51
Input token throughput (tok/s):          18950.11
Output token throughput (tok/s):         18950.11
Total token throughput (tok/s):          37900.23
Concurrency:                             510.27
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   27573.10
Median E2E Latency (ms):                 27586.49
---------------Time to First Token----------------
Mean TTFT (ms):                          2658.05
Median TTFT (ms):                        2622.65
P99 TTFT (ms):                           4843.49
---------------Inter-Token Latency----------------
Mean ITL (ms):                           24.43
Median ITL (ms):                         21.66
P95 ITL (ms):                            24.23
P99 ITL (ms):                            179.70
Max ITL (ms):                            4049.93
==================================================

main:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf
Max request concurrency:                 512
Successful requests:                     2560
Benchmark duration (s):                  161.57
Total input tokens:                      2621440
Total input text tokens:                 2621440
Total input vision tokens:               0
Total generated tokens:                  2621440
Total generated tokens (retokenized):    2539762
Request throughput (req/s):              15.84
Input token throughput (tok/s):          16225.25
Output token throughput (tok/s):         16225.25
Total token throughput (tok/s):          32450.49
Concurrency:                             510.02
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   32188.29
Median E2E Latency (ms):                 32219.19
---------------Time to First Token----------------
Mean TTFT (ms):                          2632.67
Median TTFT (ms):                        2608.20
P99 TTFT (ms):                           4823.79
---------------Inter-Token Latency----------------
Mean ITL (ms):                           28.98
Median ITL (ms):                         26.49
P95 ITL (ms):                            28.89
P99 ITL (ms):                            154.62
Max ITL (ms):                            4009.70
==================================================

Checklist

elvischenv avatar Oct 29 '25 01:10 elvischenv

Great work! May I ask how long a single tuning run takes now? Is there a switch to control whether the kernel is tuned?

FlamingoPg avatar Oct 29 '25 05:10 FlamingoPg

@FlamingoPg

May I ask how long a single tuning run takes now?

For B200 + gpt-oss-120b, it takes about 1 minute in my local test:

[2025-10-29 02:42:01] Running FlashInfer autotune...
2025-10-29 02:42:01,952 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-10-29 02:42:57,165 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[2025-10-29 02:42:57] FlashInfer autotune completed.

Is there a switch to control whether the kernel is tuned?

The current design from FlashInfer simply runs one dummy forward iteration at the max batch size under the FlashInfer autotune context. If any FlashInfer APIs (e.g. trtllm_fp4_block_scale_moe) are called during that iteration, FlashInfer autotunes them to find the best kernel.

elvischenv avatar Oct 29 '25 07:10 elvischenv

@FlamingoPg Could you review this PR? Thanks!

nvpohanh avatar Nov 05 '25 08:11 nvpohanh

@FlamingoPg Could you review this PR? Thanks!

Looks fine, let's wait for the CI

FlamingoPg avatar Nov 06 '25 03:11 FlamingoPg

@FlamingoPg There are quite a few CI failures. Are all of them caused by this PR? Or are some of them known failing issues?

nvpohanh avatar Nov 07 '25 01:11 nvpohanh

I will help you rerun failed jobs

FlamingoPg avatar Nov 10 '25 09:11 FlamingoPg

Hi @FlamingoPg, it seems that the CI failures are not related to my PR. Could you help confirm? Thanks!

elvischenv avatar Nov 13 '25 06:11 elvischenv

Hi @FlamingoPg, it seems that the CI failures are not related to my PR. Could you help confirm? Thanks!

Sure

FlamingoPg avatar Nov 13 '25 06:11 FlamingoPg

@FlamingoPg Could you check the remaining failures? thanks!

nvpohanh avatar Nov 14 '25 02:11 nvpohanh

@FlamingoPg Could you check the remaining failures? thanks!

Sure

FlamingoPg avatar Nov 14 '25 09:11 FlamingoPg

[2025-11-18 13:28:40 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/managers/scheduler.py", line 2712, in run_scheduler_process
    scheduler = Scheduler(
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/managers/scheduler.py", line 312, in __init__
    self.tp_worker = TpModelWorker(
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 340, in __init__
    self.initialize(min_per_gpu_memory)
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 507, in initialize
    self.kernel_warmup()
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 2033, in kernel_warmup
    self._flashinfer_autotune()
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 2061, in _flashinfer_autotune
    self._dummy_run(batch_size=self.req_to_token_pool.size)
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 2298, in _dummy_run
    self.attn_backend.init_forward_metadata(forward_batch)
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 838, in init_forward_metadata
    attn_backend.init_forward_metadata(forward_batch)
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 123, in init_forward_metadata
    self.forward_metadata = self._forward_metadata(forward_batch)
  File "/public_sglang_ci/runner-l1a-gpu-4567/_work/sglang/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 99, in _forward_metadata
    retrieve_parent_token = torch.empty_like(retrieve_next_token)
TypeError: empty_like(): argument 'input' (position 1) must be Tensor, not NoneType

[2025-11-18 13:28:40] Received sigquit from a child process. It usually means the child failed.

https://github.com/sgl-project/sglang/actions/runs/19458190143/job/55706006309?pr=12306

@elvischenv do you think this failure is related to this PR?

nvpohanh avatar Nov 19 '25 02:11 nvpohanh

https://github.com/sgl-project/sglang/actions/runs/19458190143/job/55706006309?pr=12306 @elvischenv do you think this failure is related to this PR?

This should be related to a PR merged 2 weeks ago: #11133. I pushed a fix; let's see if it works.
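
For context, the failing call is torch.empty_like(retrieve_next_token) with retrieve_next_token being None during the dummy run. A minimal sketch of the kind of None-guard that avoids this class of error (hypothetical helper, not necessarily what the pushed fix does):

import torch
from typing import Optional

def empty_like_or_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Hypothetical guard: only allocate when the reference tensor exists, so a
    # warmup batch that does not populate this metadata does not crash.
    return torch.empty_like(t) if t is not None else None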

elvischenv avatar Nov 19 '25 05:11 elvischenv

The two GPU pipeline failures seem to be caused by OOM and are not related to this PR. @FlamingoPg could you re-run these two pipelines? Thanks!

nvpohanh avatar Nov 20 '25 01:11 nvpohanh

Does auto-tuning also work well for low-latency cases? Or could we control this feature using server parameters?

What does "work well" mean? At least it should not cause performance regression, right?

But I agree that we can add a server flag to allow users to disable autotuning if they want.
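
For illustration only, such a flag could look like the sketch below; the name --disable-flashinfer-autotune is hypothetical and not something this PR adds.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-flashinfer-autotune",
    action="store_true",
    help="Skip the FlashInfer autotune warmup before CUDA graph capture.",
)
# With no CLI arguments the default keeps autotune enabled.
args = parser.parse_args([])
assert args.disable_flashinfer_autotune is False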

nvpohanh avatar Nov 21 '25 05:11 nvpohanh

Could you provide your launch commands? I will try to reproduce the results later.

Qiaolin-Yu avatar Nov 21 '25 23:11 Qiaolin-Yu

Waiting for the CI to pass.

Qiaolin-Yu avatar Nov 26 '25 04:11 Qiaolin-Yu

Hi @Qiaolin-Yu @FlamingoPg @Fridge003, could you help us to merge this PR? The CI failures are all unrelated. Thanks!

elvischenv avatar Nov 26 '25 07:11 elvischenv

Hi @Qiaolin-Yu @FlamingoPg @Fridge003, could you help us to merge this PR? The CI failures are all unrelated. Thanks!

@Kangyan-Zhou will help to merge this after all Nvidia CI pass.

Qiaolin-Yu avatar Nov 27 '25 01:11 Qiaolin-Yu

Test failures:

Capturing batches (bs=256 avail_mem=17.16 GB):   0%|          | 0/1 [00:26<?, ?it/s]
[2025-11-27 06:51:25 DP0 TP0 EP0] Registering 0 cuda graph addresses
E[2025-11-27 06:59:56] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-27 06:59:56] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.

and timeout at:

Capturing batches (bs=64 avail_mem=15.67 GB):   0%|          | 0/1 [00:26<?, ?it/s]

Why does capturing CUDA graphs take that long?

  • test_quantization.py: Mixtral-8x7B-Instruct-v0.1-AWQ-INT4 has slightly lower accuracy (0.612) than target 0.62

nvpohanh avatar Nov 28 '25 02:11 nvpohanh

Test failures: […] Why does capturing CUDA graphs take that long?

We've fixed a few test failures on main recently that were caused by a number of issues. Sorry for the delay, retrying now.

Kangyan-Zhou avatar Nov 28 '25 02:11 Kangyan-Zhou