[Feat] Enable PDL automatically on Hopper architecture

PopSoda2002 opened this issue 8 months ago • 4 comments (status: open)

Motivation

FlashInfer 0.2.5 supports PDL (Programmatic Dependent Launch) for its norm kernels, but PDL is currently disabled by default for them. This PR changes the code so that PDL is enabled automatically on the Hopper architecture.

Modifications

  1. Add is_hopper_arch() utility function to detect Hopper architecture (compute capability >= 9.0)
  2. Modify rmsnorm, fused_add_rmsnorm, gemma_rmsnorm and gemma_fused_add_rmsnorm functions to auto-enable PDL on Hopper
  3. Update documentation to reflect these changes
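Item 1 can be sketched as follows. This is a minimal sketch under stated assumptions, not the PR's actual code: it reads the compute capability through PyTorch's `torch.cuda.get_device_capability()` and applies the same `>= 9.0` threshold described in item 1. The optional `capability` argument is a hypothetical addition so the check can be exercised without a GPU.

```python
from typing import Optional, Tuple


def is_hopper_arch(capability: Optional[Tuple[int, int]] = None) -> bool:
    """Return True when the compute capability is 9.0 or higher.

    `capability` is a hypothetical (major, minor) override for testing;
    when omitted, the current CUDA device is queried through PyTorch.
    """
    if capability is None:
        try:
            import torch
        except ImportError:
            return False  # no PyTorch, so no CUDA device to inspect
        if not torch.cuda.is_available():
            return False
        capability = torch.cuda.get_device_capability()
    major, minor = capability
    # Hopper (e.g. H100) reports (9, 0); compare as a (major, minor) tuple.
    return (major, minor) >= (9, 0)


if __name__ == "__main__":
    print(is_hopper_arch((9, 0)))  # True: H100-class device
    print(is_hopper_arch((8, 0)))  # False: A100-class device
```

Because this is a threshold rather than an exact match, it also returns `True` on architectures newer than Hopper (compute capability 10.0 and above), which is worth keeping in mind if the PDL path should be limited to SM90.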

This implementation automatically enables PDL on Hopper GPUs while maintaining backward compatibility: callers can still override the default by passing the PDL parameter explicitly.
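The override behavior described above amounts to a tri-state flag: `None` means "decide from the architecture", while an explicit `True` or `False` from the caller always wins. Below is a minimal sketch of that pattern around an RMSNorm wrapper. The parameter name `enable_pdl`, the `on_hopper` argument, the helper `resolve_enable_pdl`, and the pure-Python reference math standing in for the FlashInfer kernel are all hypothetical; a real wrapper would forward the resolved flag to the kernel launch instead of returning it.

```python
import math
from typing import List, Optional, Tuple


def resolve_enable_pdl(enable_pdl: Optional[bool], on_hopper: bool) -> bool:
    # None means "auto": enable PDL only when running on Hopper.
    # An explicit True/False from the caller always takes precedence.
    return on_hopper if enable_pdl is None else enable_pdl


def rmsnorm(
    x: List[float],
    weight: List[float],
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
    *,
    on_hopper: bool = False,  # hypothetical: would come from an arch check
) -> Tuple[List[float], bool]:
    """Reference RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.

    Returns the output together with the PDL flag a GPU kernel launch
    would receive; this pure-Python version does not use the flag.
    """
    pdl = resolve_enable_pdl(enable_pdl, on_hopper)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)], pdl


if __name__ == "__main__":
    y, pdl = rmsnorm([3.0, 4.0], [1.0, 1.0], eps=0.0, on_hopper=True)
    print(y, pdl)
```

Defaulting to `None` rather than `False` is what keeps existing call sites working: code that never passes the flag gets the automatic behavior, and code that already passes it keeps its explicit choice.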

Checklist

  • [x] Format your code according to Code Formatting with Pre-Commit.
  • [ ] Add unit tests as outlined in Running Unit Tests.
  • [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [ ] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.

PopSoda2002, May 02 '25 15:05

Also please provide the performance benchmark after this enhancement

zhyncs, May 02 '25 21:05

Yes, there is another guy who is testing the performance!

PopSoda2002, May 02 '25 23:05

Hi, how's the progress on the benchmark?

hebiao064, May 12 '25 05:05

Still working bro

PopSoda2002, May 12 '25 09:05

Here is my benchmark (tested on H100). Command:

python3 -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend fa3 --batch 16 --input-len 1024 --output-len 10

Before this PR: [benchmark screenshot, not preserved]

After this PR: [benchmark screenshot, not preserved]

Thanks @Fridge003 for helping!

PopSoda2002, May 21 '25 00:05

batch_size  hidden_size  dtype           w/o PDL    w/ PDL
1           111          torch.float16   12.224000  9.632000
1           111          torch.bfloat16  10.976000  11.040000
1           500          torch.float16   11.008000  10.144000
1           500          torch.bfloat16  11.424000  9.632000
1           1024         torch.float16   11.360000  9.696000
1           1024         torch.bfloat16  11.392000  9.984000
1           3072         torch.float16   12.544000  10.304000
1           3072         torch.bfloat16  11.136000  11.264000
1           3584         torch.float16   11.616000  10.016000
1           3584         torch.bfloat16  11.296000  10.368000
1           4096         torch.float16   11.648000  11.456000
1           4096         torch.bfloat16  11.680000  11.456000
1           8192         torch.float16   13.248000  10.880000
1           8192         torch.bfloat16  12.128000  11.936000
1           16384        torch.float16   12.864000  11.904000
1           16384        torch.bfloat16  13.216000  11.584000
19          111          torch.float16   11.776000  11.584000
19          111          torch.bfloat16  12.864000  11.584000
19          500          torch.float16   12.544000  9.920000
19          500          torch.bfloat16  11.584000  9.920000
19          1024         torch.float16   11.776000  10.496000
19          1024         torch.bfloat16  12.672000  10.048000
19          3072         torch.float16   12.832000  11.520000
19          3072         torch.bfloat16  11.456000  10.528000
19          3584         torch.float16   12.992000  10.272000
19          3584         torch.bfloat16  12.992000  10.272000
19          4096         torch.float16   12.032000  11.744000
19          4096         torch.bfloat16  13.088000  10.336000
19          8192         torch.float16   12.576000  11.264000
19          8192         torch.bfloat16  12.672000  11.392000
19          16384        torch.float16   13.856000  12.704000
19          16384        torch.bfloat16  14.048000  13.728000
99          111          torch.float16   13.088000  10.464000
99          111          torch.bfloat16  13.088000  10.848000
99          500          torch.float16   11.392000  10.400000
99          500          torch.bfloat16  11.584000  10.016000
99          1024         torch.float16   11.552000  10.688000
99          1024         torch.bfloat16  12.960000  10.688000
99          3072         torch.float16   12.288000  10.688000
99          3072         torch.bfloat16  13.376000  11.136000
99          3584         torch.float16   12.640000  12.352000
99          3584         torch.bfloat16  12.640000  11.328000
99          4096         torch.float16   12.704000  11.552000
99          4096         torch.bfloat16  14.080000  11.552000
99          8192         torch.float16   16.192000  14.016000
99          8192         torch.bfloat16  16.160000  14.752000
99          16384        torch.float16   18.208001  15.776001
99          16384        torch.bfloat16  18.176001  17.120000
989         111          torch.float16   12.864000  11.648000
989         111          torch.bfloat16  12.896000  11.680000
989         500          torch.float16   13.408000  13.280000
989         500          torch.bfloat16  14.528000  12.224000
989         1024         torch.float16   18.975999  17.696001
989         1024         torch.bfloat16  18.975999  17.664000
989         3072         torch.float16   23.647999  22.399999
989         3072         torch.bfloat16  23.712000  22.431999
989         3584         torch.float16   24.831999  23.520000
989         3584         torch.bfloat16  24.831999  23.456000
989         4096         torch.float16   23.808001  22.368001
989         4096         torch.bfloat16  23.680000  21.152001
989         8192         torch.float16   35.808001  33.408001
989         8192         torch.bfloat16  34.880001  34.015998
989         16384        torch.float16   64.800002  63.616000
989         16384        torch.bfloat16  64.576000  64.032003
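One way to read the table is to compare the two columns row by row. The sketch below computes the ratio w/o PDL ÷ w/ PDL for a handful of rows copied verbatim from the table, plus their geometric mean. The table does not state its unit, so only the unit-free ratio is computed; treating lower values as better is an assumption (the numbers look like per-call timings), and the helper names `speedup` and `geomean` are hypothetical.

```python
from math import prod
from typing import Iterable

# (batch_size, hidden_size, dtype, w/o PDL, w/ PDL), copied from the table.
ROWS = [
    (1, 111, "torch.float16", 12.224000, 9.632000),
    (1, 111, "torch.bfloat16", 10.976000, 11.040000),
    (19, 500, "torch.float16", 12.544000, 9.920000),
    (99, 8192, "torch.float16", 16.192000, 14.016000),
    (989, 16384, "torch.bfloat16", 64.576000, 64.032003),
]


def speedup(without_pdl: float, with_pdl: float) -> float:
    # A ratio above 1.0 means the PDL measurement was lower than the baseline.
    return without_pdl / with_pdl


def geomean(values: Iterable[float]) -> float:
    # Geometric mean is the usual way to average ratios.
    vals = list(values)
    return prod(vals) ** (1.0 / len(vals))


if __name__ == "__main__":
    ratios = []
    for batch, hidden, dtype, base, pdl in ROWS:
        r = speedup(base, pdl)
        ratios.append(r)
        print(f"{batch:>4} {hidden:>6} {dtype:<15} {r:.3f}x")
    print(f"geometric mean over these rows: {geomean(ratios):.3f}x")
```

Only a subset of rows is included to keep the example short; pasting the full table into `ROWS` works the same way.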

Benchmark done. cc @zhyncs @hebiao064

FlamingoPg, May 22 '25 06:05

LFG!

hebiao064, May 24 '25 04:05

Hi @zhyncs, could you help review this PR? I think it is ready to merge.

PopSoda2002, May 29 '25 17:05