
[Feature] Apply Cublas Grouped Gemm kernel

Fridge003 opened this issue 10 months ago · 1 comment

Motivation

#3323

The grouped GEMM kernel added in cuBLAS 12.5 is useful: it can be applied to the MoE EP layer and the LoRA layer for acceleration.
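For context, a grouped GEMM computes several independent matrix products, with possibly different shapes per group, in a single kernel launch; this matches the MoE EP case of one GEMM per expert. A minimal NumPy sketch of the semantics (shapes here are made up for illustration, not the DeepSeek ones):

```python
import numpy as np

# Reference semantics of a grouped GEMM: C_i = A_i @ B_i for each group i.
# Per-group M may differ (e.g. different token counts routed to each expert).
def grouped_gemm_ref(As, Bs):
    return [A @ B for A, B in zip(As, Bs)]

rng = np.random.default_rng(0)
As = [rng.standard_normal((m, 64)) for m in (3, 5, 7)]   # per-group M varies
Bs = [rng.standard_normal((64, 32)) for _ in range(3)]   # shared N, K
Cs = grouped_gemm_ref(As, Bs)
print([C.shape for C in Cs])  # [(3, 32), (5, 32), (7, 32)]
```

The cuBLAS kernel fuses this loop into one launch, which is where the speedup over launching one GEMM per group comes from.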

Modifications

  • Add cublas_grouped_gemm to the sgl-kernel library, and provide an accuracy test and a benchmark script.
  • Update the documentation for this feature.

Environment:

Torch 2.5.1, CUDA 12.5, cuBLAS 12.5.3.2, sglang 0.4.3. Since sglang doesn't support torch 2.6 yet, build the environment as follows:

  1. First, make sure the CUDA version is >= 12.5 with nvcc -V.
  2. Then install sglang as described in the official documentation.
  3. Reinstall cuBLAS 12.5 with pip install nvidia-cublas-cu12==12.5.3.2 so that cuBLAS is upgraded.
  4. Compile the new sgl-kernel library.
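The steps above can be sketched as a shell session; the sglang install command follows the official docs, and the sgl-kernel build command is an assumption (check the repo's build instructions for the exact invocation):

```shell
# 1. Verify the CUDA toolkit version is >= 12.5
nvcc -V

# 2. Install sglang per the official documentation
pip install "sglang[all]"

# 3. Upgrade the cuBLAS wheel to 12.5
pip install nvidia-cublas-cu12==12.5.3.2

# 4. Build the sgl-kernel library from source (assumed invocation)
cd sgl-kernel && pip install -e .
```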

Accuracy Test

python3 sgl-kernel/tests/test_cublas_grouped_gemm.py 
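The general shape of such an accuracy test is to compare the kernel's output against a plain reference GEMM within a tolerance; the sketch below illustrates the pattern (names, tolerances, and the stand-in "kernel output" are hypothetical, not the actual sgl-kernel test):

```python
import numpy as np

# Hypothetical accuracy-check pattern: fast kernel output vs. reference GEMM.
def check_close(out, ref, rtol=1e-2, atol=1e-3):
    return bool(np.allclose(out, ref, rtol=rtol, atol=atol))

a = np.random.default_rng(1).standard_normal((4, 8)).astype(np.float32)
b = np.random.default_rng(2).standard_normal((8, 16)).astype(np.float32)
ref = a @ b
out = ref + 1e-5  # stand-in for the kernel's slightly-perturbed output
print(check_close(out, ref))  # True
```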

Kernel Benchmark

DeepSeek-V2 setting

On the DeepSeek-V2 setting with TP size = 8 (group size = 20), N = 3072, K = 5120:

python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2 --tp-size 8

Result in GB per second: (screenshot of benchmark results, 2025-02-17)
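As a reading aid, GB/s figures for a grouped GEMM are typically derived from the total memory traffic divided by latency; the formula below is an assumption about how such benchmarks count bytes (2-byte fp16/bf16 elements), and the M and latency values are hypothetical:

```python
# Assumed GB/s derivation for a grouped GEMM benchmark: total bytes read
# (A: M*K, B: K*N) plus written (C: M*N) per group, over measured latency.
def grouped_gemm_gbps(group_size, m, n, k, latency_s, bytes_per_elem=2):
    bytes_moved = group_size * (m * k + k * n + m * n) * bytes_per_elem
    return bytes_moved / latency_s / 1e9

# DeepSeek-V2 setting (group size 20, N=3072, K=5120), hypothetical M/latency:
print(round(grouped_gemm_gbps(20, 1024, 3072, 5120, 1e-3), 1))  # 964.7
```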

DeepSeek-V2-Lite setting

On the DeepSeek-V2-Lite setting with TP size = 2 (group size = 32), N = 2816, K = 2048:

python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2-Lite --tp-size 2

Result in GB per second: (screenshot of benchmark results, 2025-02-17)

Checklist

  • [x] Format your code according to the Code Formatting with Pre-Commit.
  • [x] Add unit tests as outlined in the Running Unit Tests.
  • [x] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [x] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [x] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [x] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.

Fridge003 avatar Feb 17 '25 08:02 Fridge003

Since PyTorch 2.5.1 only supports CUDA 12.4 in the official docs, and we cannot change the PyTorch version easily, we need to update the documentation to guide users to reinstall PyTorch if they want to use grouped GEMM to accelerate their models.

yizhang2077 avatar Feb 17 '25 17:02 yizhang2077

LGTM cc @zhyncs

yizhang2077 avatar Feb 18 '25 02:02 yizhang2077

amazing work!

zhyncs avatar Feb 18 '25 07:02 zhyncs