[Feat] Enable PDL automatically on Hopper architecture
Motivation
FlashInfer 0.2.5 already supports PDL (Programmatic Dependent Launch) for its norm kernels, but it is disabled by default. This PR modifies the code to enable it automatically on the Hopper architecture.
Modifications
- Add an `is_hopper_arch()` utility function to detect the Hopper architecture (compute capability >= 9.0)
- Modify the `rmsnorm`, `fused_add_rmsnorm`, `gemma_rmsnorm`, and `gemma_fused_add_rmsnorm` functions to auto-enable PDL on Hopper
- Update documentation to reflect these changes
This enables the PDL optimization automatically on Hopper GPUs while maintaining backward compatibility: callers can still override the behavior explicitly through the existing parameter.
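As a rough illustration, here is a minimal sketch of the detection and defaulting logic, assuming a PyTorch-based capability check and a hypothetical `enable_pdl` keyword; the names and wiring in the actual patch may differ.

```python
from typing import Optional

import torch


def is_hopper_arch() -> bool:
    """Return True if the current CUDA device is Hopper (compute capability >= 9.0)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9


def resolve_enable_pdl(enable_pdl: Optional[bool]) -> bool:
    """Default PDL to on for Hopper when the caller does not set it explicitly."""
    # `enable_pdl` is a hypothetical keyword name used here for illustration only.
    if enable_pdl is None:
        return is_hopper_arch()
    return enable_pdl
```

With this defaulting scheme, callers that pass an explicit value keep their current behavior, while callers that leave the flag unset get PDL on Hopper and the previous behavior on older architectures.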
Checklist
- [x] Format your code according to the Code Formatting with Pre-Commit.
- [ ] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [ ] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.
Also, please provide a performance benchmark for this enhancement.
Yes, there is another guy who is testing the performance!
Hi, how's the progress on the benchmark?
Still working bro
Here is my benchmark (tested on an H100). Command:
python3 -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend fa3 --batch 16 --input-len 1024 --output-len 10
Before this PR:
After:
Thanks @Fridge003 for helping!
| batch_size | hidden_size | dtype | latency w/o PDL | latency w/ PDL |
|---|---|---|---|---|
| 1 | 111 | torch.float16 | 12.224000 | 9.632000 |
| 1 | 111 | torch.bfloat16 | 10.976000 | 11.040000 |
| 1 | 500 | torch.float16 | 11.008000 | 10.144000 |
| 1 | 500 | torch.bfloat16 | 11.424000 | 9.632000 |
| 1 | 1024 | torch.float16 | 11.360000 | 9.696000 |
| 1 | 1024 | torch.bfloat16 | 11.392000 | 9.984000 |
| 1 | 3072 | torch.float16 | 12.544000 | 10.304000 |
| 1 | 3072 | torch.bfloat16 | 11.136000 | 11.264000 |
| 1 | 3584 | torch.float16 | 11.616000 | 10.016000 |
| 1 | 3584 | torch.bfloat16 | 11.296000 | 10.368000 |
| 1 | 4096 | torch.float16 | 11.648000 | 11.456000 |
| 1 | 4096 | torch.bfloat16 | 11.680000 | 11.456000 |
| 1 | 8192 | torch.float16 | 13.248000 | 10.880000 |
| 1 | 8192 | torch.bfloat16 | 12.128000 | 11.936000 |
| 1 | 16384 | torch.float16 | 12.864000 | 11.904000 |
| 1 | 16384 | torch.bfloat16 | 13.216000 | 11.584000 |
| 19 | 111 | torch.float16 | 11.776000 | 11.584000 |
| 19 | 111 | torch.bfloat16 | 12.864000 | 11.584000 |
| 19 | 500 | torch.float16 | 12.544000 | 9.920000 |
| 19 | 500 | torch.bfloat16 | 11.584000 | 9.920000 |
| 19 | 1024 | torch.float16 | 11.776000 | 10.496000 |
| 19 | 1024 | torch.bfloat16 | 12.672000 | 10.048000 |
| 19 | 3072 | torch.float16 | 12.832000 | 11.520000 |
| 19 | 3072 | torch.bfloat16 | 11.456000 | 10.528000 |
| 19 | 3584 | torch.float16 | 12.992000 | 10.272000 |
| 19 | 3584 | torch.bfloat16 | 12.992000 | 10.272000 |
| 19 | 4096 | torch.float16 | 12.032000 | 11.744000 |
| 19 | 4096 | torch.bfloat16 | 13.088000 | 10.336000 |
| 19 | 8192 | torch.float16 | 12.576000 | 11.264000 |
| 19 | 8192 | torch.bfloat16 | 12.672000 | 11.392000 |
| 19 | 16384 | torch.float16 | 13.856000 | 12.704000 |
| 19 | 16384 | torch.bfloat16 | 14.048000 | 13.728000 |
| 99 | 111 | torch.float16 | 13.088000 | 10.464000 |
| 99 | 111 | torch.bfloat16 | 13.088000 | 10.848000 |
| 99 | 500 | torch.float16 | 11.392000 | 10.400000 |
| 99 | 500 | torch.bfloat16 | 11.584000 | 10.016000 |
| 99 | 1024 | torch.float16 | 11.552000 | 10.688000 |
| 99 | 1024 | torch.bfloat16 | 12.960000 | 10.688000 |
| 99 | 3072 | torch.float16 | 12.288000 | 10.688000 |
| 99 | 3072 | torch.bfloat16 | 13.376000 | 11.136000 |
| 99 | 3584 | torch.float16 | 12.640000 | 12.352000 |
| 99 | 3584 | torch.bfloat16 | 12.640000 | 11.328000 |
| 99 | 4096 | torch.float16 | 12.704000 | 11.552000 |
| 99 | 4096 | torch.bfloat16 | 14.080000 | 11.552000 |
| 99 | 8192 | torch.float16 | 16.192000 | 14.016000 |
| 99 | 8192 | torch.bfloat16 | 16.160000 | 14.752000 |
| 99 | 16384 | torch.float16 | 18.208001 | 15.776001 |
| 99 | 16384 | torch.bfloat16 | 18.176001 | 17.120000 |
| 989 | 111 | torch.float16 | 12.864000 | 11.648000 |
| 989 | 111 | torch.bfloat16 | 12.896000 | 11.680000 |
| 989 | 500 | torch.float16 | 13.408000 | 13.280000 |
| 989 | 500 | torch.bfloat16 | 14.528000 | 12.224000 |
| 989 | 1024 | torch.float16 | 18.975999 | 17.696001 |
| 989 | 1024 | torch.bfloat16 | 18.975999 | 17.664000 |
| 989 | 3072 | torch.float16 | 23.647999 | 22.399999 |
| 989 | 3072 | torch.bfloat16 | 23.712000 | 22.431999 |
| 989 | 3584 | torch.float16 | 24.831999 | 23.520000 |
| 989 | 3584 | torch.bfloat16 | 24.831999 | 23.456000 |
| 989 | 4096 | torch.float16 | 23.808001 | 22.368001 |
| 989 | 4096 | torch.bfloat16 | 23.680000 | 21.152001 |
| 989 | 8192 | torch.float16 | 35.808001 | 33.408001 |
| 989 | 8192 | torch.bfloat16 | 34.880001 | 34.015998 |
| 989 | 16384 | torch.float16 | 64.800002 | 63.616000 |
| 989 | 16384 | torch.bfloat16 | 64.576000 | 64.032003 |
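For context, below is a hedged sketch of how per-kernel numbers like those in the table above could be collected with CUDA events. It assumes `flashinfer.norm.rmsnorm` accepts an `enable_pdl` keyword; it is not the script that produced the table.

```python
import torch
import flashinfer


def bench_rmsnorm(batch_size: int, hidden_size: int, dtype: torch.dtype,
                  enable_pdl: bool, iters: int = 100) -> float:
    """Average rmsnorm latency in microseconds (illustrative only)."""
    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    w = torch.randn(hidden_size, dtype=dtype, device="cuda")
    # Warm up so one-time compilation/caching does not skew the timing.
    for _ in range(10):
        flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e3  # ms per iter -> us


if __name__ == "__main__":
    for pdl in (False, True):
        print(f"enable_pdl={pdl}:", bench_rmsnorm(19, 4096, torch.bfloat16, enable_pdl=pdl))
```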
Benchmark done, cc @zhyncs @hebiao064
LFG!
Hi @zhyncs, can you help review this PR? I think it is ready to merge.