[Feature] Apply Cublas Grouped Gemm kernel
## Motivation
#3323
The grouped GEMM kernel added in cuBLAS 12.5 is useful: it can be applied to the MoE EP layer and the LoRA layer for acceleration.
## Modifications
- Add `cublas_grouped_gemm` to the sgl-kernel library, and provide an accuracy test and a benchmark script (a hedged usage sketch follows this list).
- Update the documentation for this feature.
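For context, here is a minimal usage sketch. The import path, binding name, and argument order are assumptions for illustration; see `sgl-kernel/tests/test_cublas_grouped_gemm.py` for the actual interface.

```python
# Hedged usage sketch: binding name and signature are assumed, not the
# committed API (see sgl-kernel/tests/test_cublas_grouped_gemm.py).
import torch
from sgl_kernel import cublas_grouped_gemm  # assumed import path

# One independent (m_i, K) x (K, N) problem per group, as in an MoE EP
# layer where each expert sees a different number of tokens.
ms, N, K = [1, 4, 16], 3072, 5120
a = [torch.randn(m, K, device="cuda", dtype=torch.float16) for m in ms]
b = [torch.randn(K, N, device="cuda", dtype=torch.float16) for _ in ms]
c = [torch.empty(m, N, device="cuda", dtype=torch.float16) for m in ms]

cublas_grouped_gemm(a, b, c, torch.float16)  # assumed signature
```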
Environment: Torch 2.5.1, CUDA 12.5, cuBLAS 12.5.3.2, sglang 0.4.3.

Since sglang doesn't support torch 2.6 yet, to build the environment:
- First, make sure the CUDA version is >= 12.5 with `nvcc -V`.
- Then install sglang as described in the official documentation.
- Reinstall cuBLAS 12.5 through `pip install nvidia-cublas-cu12==12.5.3.2` so that cuBLAS is upgraded (a quick version check follows this list).
- Compile the new sgl-kernel library.
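To double-check the pinned versions before compiling, a small snippet using standard `torch` and `importlib.metadata` calls:

```python
# Verify the environment matches the versions listed above.
import importlib.metadata as md
import torch

print(torch.__version__)                 # expect 2.5.1
print(torch.version.cuda)                # CUDA version torch was built against
print(md.version("nvidia-cublas-cu12"))  # expect 12.5.3.2 after the reinstall
```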
## Accuracy Test
`python3 sgl-kernel/tests/test_cublas_grouped_gemm.py`
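A sketch of the style of check such a test performs: run the kernel, then compare each group's output against a plain `torch.matmul` reference. The binding name and signature are the same assumptions as in the sketch above.

```python
# Compare the grouped GEMM kernel against a per-group torch reference.
import torch
from sgl_kernel import cublas_grouped_gemm  # assumed import path

ms, N, K = [1, 8, 32], 128, 256
a = [torch.randn(m, K, device="cuda", dtype=torch.float16) for m in ms]
b = [torch.randn(K, N, device="cuda", dtype=torch.float16) for _ in ms]
c = [torch.empty(m, N, device="cuda", dtype=torch.float16) for m in ms]

cublas_grouped_gemm(a, b, c, torch.float16)  # assumed signature
for a_i, b_i, c_i in zip(a, b, c):
    torch.testing.assert_close(c_i, a_i @ b_i, rtol=1e-2, atol=1e-2)
```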
## Kernel Benchmark
### DeepSeek-V2 setting
On the DeepSeek-V2 setting with TP size = 8 (group size = 20), N = 3072, K = 5120:
`python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2 --tp-size 8`
Result in GB per second:
### DeepSeek-V2-Lite setting
On the DeepSeek-V2-Lite setting with TP size = 2 (group size = 32), N = 2816, K = 2048:
`python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2-Lite --tp-size 2`
Result in GB per second:
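For reference on how a GB/s figure for a grouped GEMM can be derived: the bytes moved for A (m×k), B (k×n), and C (m×n), summed over all groups, divided by the measured kernel time. This is an assumption about the metric, not the benchmark script's exact code.

```python
# Illustrative back-of-envelope bandwidth for a grouped GEMM.
def grouped_gemm_gbps(ms, n, k, seconds, bytes_per_elem=2):
    # Bytes read/written per group: A (m*k) + B (k*n) + C (m*n).
    total_bytes = sum(m * k + k * n + m * n for m in ms) * bytes_per_elem
    return total_bytes / seconds / 1e9

# Hypothetical example: DeepSeek-V2 setting (group size 20, N=3072, K=5120),
# 64 tokens per expert in fp16, and a made-up 1 ms kernel time.
print(grouped_gemm_gbps([64] * 20, 3072, 5120, 1e-3))
```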
## Checklist
- [x] Format your code according to the Code Formatting with Pre-Commit.
- [x] Add unit tests as outlined in the Running Unit Tests.
- [x] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [x] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [x] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [x] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.
Since PyTorch 2.5.1 only supports CUDA 12.4 in the official docs, and we cannot change the PyTorch version easily, we need to update the documentation to guide users to reinstall PyTorch if they want to use grouped GEMM to accelerate their models.
LGTM cc @zhyncs
amazing work!