[Perf] Tunings for SM100 FP8 CUTLASS kernel
I noticed that the FP8 CUTLASS kernel for Blackwell (SM100) had only a single default set of configs. This PR adds new configs tuned for small M (< 128).
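For context, the change boils down to selecting a different compile-time tile/cluster configuration when M is small, instead of always falling through to the single default. Below is a minimal sketch of that dispatch pattern, with hypothetical config names and M cut-offs (the actual configs, tile shapes, and boundaries added in this PR may differ):

```cpp
// Illustrative sketch only: config names, M cut-offs, and the tile-shape
// comments are placeholders for the dispatch pattern, not the exact
// definitions added in this PR.
#include <cstdint>
#include <cstdio>

// Each config would bundle a CTA tile shape, cluster shape, and schedule for
// the CUTLASS 3.x kernel builder; stubbed out here with just a name.
struct Sm100Fp8ConfigDefault { static constexpr const char* name = "default"; };
struct Sm100Fp8ConfigM64     { static constexpr const char* name = "small-M (< 64)"; };
struct Sm100Fp8ConfigM128    { static constexpr const char* name = "mid-M (64..127)"; };

// Stand-in for instantiating and launching the CUTLASS GEMM for `Config`.
template <typename Config>
void run_gemm(int64_t m, int64_t n, int64_t k) {
  std::printf("M=%lld N=%lld K=%lld -> %s config\n",
              static_cast<long long>(m), static_cast<long long>(n),
              static_cast<long long>(k), Config::name);
}

// Runtime dispatch on the GEMM M dimension (number of tokens in the batch).
void dispatch_sm100_fp8(int64_t m, int64_t n, int64_t k) {
  if (m < 64) {
    run_gemm<Sm100Fp8ConfigM64>(m, n, k);
  } else if (m < 128) {
    run_gemm<Sm100Fp8ConfigM128>(m, n, k);
  } else {
    run_gemm<Sm100Fp8ConfigDefault>(m, n, k);
  }
}

int main() {
  dispatch_sm100_fp8(16, 6144, 4096);    // decode-sized batch -> small-M config
  dispatch_sm100_fp8(4096, 6144, 4096);  // large batch -> existing default
}
```

In the real kernel each config would feed the CUTLASS kernel builder; the point here is only that the M ranges quoted in the speedup summary map to distinct compile-time configurations.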
For Llama 8B on B200, these tunings give:
- a 1.7x to 2.5x speedup at M < 64
- a 1.1x to 1.3x speedup at 64 <= M < 128
Kernel benchmarks were run with https://github.com/vllm-project/vllm/pull/17126; a quick check of the headline ratios against the raw GB/s numbers is sketched below.
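To make the ranges above concrete, here is a throwaway check (not part of the PR) of the ratios for one shape, using the fp8-tensor-w-tensor-a-noquant GB/s values for N=6144, K=4096 copied verbatim from the two runs below; the noquant column is used because it isolates the GEMM kernel from quantization overhead.

```cpp
// Throwaway sanity check, not part of the PR. GB/s values are copied verbatim
// from the fp8-tensor-w-tensor-a-noquant column of the N=6144, K=4096 tables
// in the "original tunings" and "new tunings" runs below.
#include <cstdio>

int main() {
  struct Row { int m; double before_gbps, after_gbps; };
  const Row rows[] = {
      {1,    5.059773,  10.641609},
      {16,  96.735811, 177.382610},
      {64, 758.622920, 904.338200},
  };
  for (const Row& r : rows)
    std::printf("M=%-3d speedup = %.2fx\n", r.m, r.after_gbps / r.before_gbps);
}
```

For this shape it prints roughly 2.10x, 1.83x, and 1.19x, in line with the ranges above; the other shapes can be read the same way.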
# B200 original tunings
```
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 11.842750 4.384695 4.141685 5.059773 4.960429
1 16.0 192.371772 85.477309 79.608508 96.735811 96.406114
2 64.0 807.719071 589.708519 516.921588 758.622920 752.988363
3 128.0 1493.427964 1164.192322 1022.805103 1505.913815 1492.575158
4 256.0 2330.405171 2220.442815 1813.278596 3008.296875 2981.101615
5 512.0 2907.547764 2793.742472 2203.694550 3672.868551 3653.372327
6 1024.0 3185.960190 3869.274803 2936.256100 5253.748472 5213.397951
7 2048.0 3372.689911 4142.692217 3109.779229 5727.176215 5689.243846
8 4096.0 3462.188289 4291.145308 3204.745041 5911.353402 5899.836283
9 8192.0 3529.249912 4397.600739 3282.537443 6098.495194 6085.388654
10 16384.0 3566.596828 4546.730027 3373.615805 6298.621054 6287.915519
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 5.481086 3.994605 3.679018 4.808255 4.673826
1 16.0 145.731275 75.527900 68.386568 91.504401 90.064710
2 64.0 575.803168 394.026762 345.690693 507.879320 504.012236
3 128.0 1139.410454 779.473527 685.449889 1007.821323 998.688924
4 256.0 1875.656947 1486.448134 1213.644060 2014.891023 1996.817344
5 512.0 2763.397677 2596.513320 1876.577101 3946.713561 3910.224984
6 1024.0 3142.721171 3239.147685 2321.719735 4832.091875 4807.995344
7 2048.0 3318.816545 3651.365389 2562.171046 5485.452742 5446.915454
8 4096.0 3414.255510 3752.133901 2590.631646 5823.098124 5807.590771
9 8192.0 3504.274870 3834.227574 2653.123452 6003.879744 5993.921400
10 16384.0 4048.244939 3953.554872 2722.017045 6171.108894 6164.274055
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 12.157825 5.612388 5.539948 6.181624 6.000656
1 16.0 172.970068 114.972518 112.622180 120.732156 121.298006
2 64.0 674.595646 943.557386 905.512283 1050.581139 1062.964463
3 128.0 1300.962379 1858.124579 1761.151736 2047.242680 2042.379529
4 256.0 2216.287867 3554.804884 3329.506996 4014.308414 3998.976858
5 512.0 2726.096815 4795.019635 4364.503526 5341.365889 5324.928601
6 1024.0 2959.186345 5294.689937 4836.956583 5821.938932 5814.327508
7 2048.0 3761.815356 5591.323764 5110.608734 6088.998704 6083.975578
8 4096.0 3574.236814 5806.798749 5289.014726 6308.693639 6300.039740
9 8192.0 4019.744188 5913.920427 5381.996320 6414.258701 6403.837270
10 16384.0 4106.402587 5836.563139 5324.548319 6329.917220 6323.172853
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 9.777438 6.095286 5.844354 6.928244 6.744669
1 16.0 162.025163 154.025715 144.935378 185.241624 184.836429
2 64.0 616.911342 611.287938 576.028417 740.686773 738.542393
3 128.0 1204.949288 1216.822830 1146.637308 1476.530711 1471.302346
4 256.0 2111.100423 2293.131110 2051.262051 2952.089686 2942.423441
5 512.0 2937.260107 3796.200098 3114.371919 5730.980802 5704.233235
6 1024.0 3488.413658 4239.281076 3446.170900 6216.238000 6212.696169
7 2048.0 3685.025884 4490.287887 3635.015494 6539.948924 6529.935247
8 4096.0 3734.385998 4532.309092 3662.371663 6636.056913 6629.569313
9 8192.0 3758.236603 4645.041087 3762.776234 6743.351536 6718.151235
10 16384.0 4306.204970 4738.751066 3832.668281 6813.248826 6805.834591
```
# B200 new tunings
```
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 11.846480 8.452990 7.572770 10.641609 10.534006
1 16.0 192.336576 141.348902 125.896661 177.382610 176.327304
2 64.0 809.054397 676.540533 583.294353 904.338200 901.570760
3 128.0 1493.297679 1353.990985 1177.501355 1843.712204 1838.021954
4 256.0 2350.962645 2220.517499 1815.399353 3008.279996 2980.624284
5 512.0 2903.143544 2787.914646 2203.210421 3672.087705 3653.324502
6 1024.0 3206.446347 3863.193810 2935.647985 5250.389648 5210.453726
7 2048.0 3371.772052 4139.752282 3097.937554 5695.430751 5665.540960
8 4096.0 3462.810635 4292.139714 3208.626955 5912.431026 5900.400438
9 8192.0 3530.033816 4397.610400 3283.561232 6095.620584 6086.099140
10 16384.0 3566.388107 4547.710555 3374.347531 6297.712858 6287.613786
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 5.481237 7.194318 6.265508 9.530527 9.511873
1 16.0 145.717362 114.108143 99.566867 152.901809 152.985032
2 64.0 575.814872 449.980812 389.597206 602.248063 600.669616
3 128.0 1139.590749 905.278908 787.789674 1234.155324 1229.686850
4 256.0 1876.291317 1486.630754 1213.623961 2015.033919 1996.965312
5 512.0 2764.009823 2591.524744 1879.343830 3946.809025 3909.465672
6 1024.0 3142.383890 3230.337927 2321.113332 4830.144592 4804.234807
7 2048.0 3318.579790 3651.632276 2557.730213 5485.786728 5447.503083
8 4096.0 3414.370536 3749.333637 2590.511600 5821.179268 5806.526663
9 8192.0 3504.070655 3834.735907 2654.077070 6004.740320 5992.704228
10 16384.0 4047.938574 3953.740013 2722.086526 6170.869834 6163.802150
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 12.139843 13.499578 13.086876 14.939551 14.847452
1 16.0 172.540520 219.474892 210.204189 239.697254 230.664293
2 64.0 674.577865 1030.361791 981.054381 1149.836193 1148.147840
3 128.0 1301.165223 1872.693659 1806.220627 2081.156220 2123.675357
4 256.0 2216.100916 3552.279693 3336.364808 4014.755649 3998.710144
5 512.0 2725.949811 4792.585207 4363.568374 5341.086466 5326.349901
6 1024.0 2959.125081 5293.626173 4835.906799 5821.708743 5813.863106
7 2048.0 3761.956557 5591.420874 5110.979048 6089.098791 6083.292157
8 4096.0 3574.307512 5807.307181 5288.928725 6308.724765 6300.209702
9 8192.0 4019.744388 5914.312745 5382.379939 6414.279709 6404.333069
10 16384.0 4106.374801 5836.666966 5324.534539 6329.588982 6323.079665
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs GB/s:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 9.733832 12.054116 11.104702 15.095716 15.057972
1 16.0 161.979220 189.099701 175.757232 238.928259 238.785025
2 64.0 616.771351 751.830865 698.306492 956.291017 954.205153
3 128.0 1204.865625 1483.958057 1379.201876 1886.236952 1883.095434
4 256.0 2110.901491 2295.331478 2051.189524 2952.193198 2942.234970
5 512.0 2937.621666 3799.227788 3114.414028 5730.719186 5704.042751
6 1024.0 3488.194848 4240.428960 3442.777987 6216.468683 6211.739077
7 2048.0 3684.893025 4490.729893 3634.670625 6540.363564 6530.731937
8 4096.0 3734.498831 4532.994968 3662.073259 6636.363393 6630.525610
9 8192.0 3757.974111 4645.556715 3763.901169 6743.521712 6721.473664
10 16384.0 4305.834102 4738.611024 3832.780746 6811.809892 6806.390517
```
cc: @chenyang78 @drisspg