
[ROCm] Enable Fused MLA Triton kernel for DeepSeekV3

Open · lcskrishna opened this issue 10 months ago · 1 comment

This PR introduces a fused MLA decode Triton kernel on ROCm. To enable this feature, set the environment variable `SGLANG_ROCM_FUSED_DECODE_MLA=1`.
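As a minimal sketch of how an environment-variable feature gate like this is typically consumed (the function name here is hypothetical; the actual sglang ROCm integration may read the flag differently):

```python
import os

def rocm_fused_decode_mla_enabled() -> bool:
    """Return True when SGLANG_ROCM_FUSED_DECODE_MLA=1 is set.

    Hypothetical helper illustrating the gating pattern; not the
    actual sglang code path.
    """
    return os.environ.get("SGLANG_ROCM_FUSED_DECODE_MLA", "0") == "1"
```

With a gate like this, the fused kernel is opt-in and the default decode path is unchanged unless the flag is set.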

Triton kernel authors: @juuso-oskari (Juuso Korhonen), @Chi-Chu319 (Tianxing Wu), and @vgokhale (Vinayak Gokhale)

lcskrishna avatar Jan 31 '25 17:01 lcskrishna

cc: @sunway513

lcskrishna avatar Feb 13 '25 14:02 lcskrishna

@HaiShaw This is ready for review.

lcskrishna avatar Feb 24 '25 09:02 lcskrishna

@lcskrishna can you share some perf uplift or comparison here? Thanks!

HaiShaw avatar Feb 24 '25 10:02 HaiShaw

@saienduri can you help for a look, why all the CIs are skipped and unable to rerun?

HaiShaw avatar Feb 24 '25 10:02 HaiShaw

> @lcskrishna can you share some perf uplift or comparison here? Thanks!

@HaiShaw Here are some basic benchmarks of DeepSeek-V3 with this feature.

Without Fusion:

% python -m sglang.bench_one_batch --batch-size 4 --input 2048 --output 256 --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

Decode.  median latency: 0.03850 s, median throughput:    103.89 token/s
Total. latency: 10.656 s, throughput:    864.89 token/s

With this feature:

% SGLANG_ROCM_FUSED_DECODE_MLA=1  python -m sglang.bench_one_batch --batch-size 4 --input 2048 --output 256 --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

Decode.  median latency: 0.03581 s, median throughput:    111.71 token/s
Total. latency:  9.970 s, throughput:    924.39 token/s
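From the numbers above, the relative uplift works out to roughly 7.5% on decode throughput and about 6.9% on total throughput. A quick check of that arithmetic:

```python
# Throughput figures copied from the benchmark output above (tokens/s).
baseline_decode_tps = 103.89  # without fusion
fused_decode_tps = 111.71     # with SGLANG_ROCM_FUSED_DECODE_MLA=1
baseline_total_tps = 864.89
fused_total_tps = 924.39

decode_uplift = (fused_decode_tps / baseline_decode_tps - 1) * 100
total_uplift = (fused_total_tps / baseline_total_tps - 1) * 100
print(f"decode throughput uplift: {decode_uplift:.1f}%")  # ~7.5%
print(f"total throughput uplift:  {total_uplift:.1f}%")   # ~6.9%
```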

lcskrishna avatar Feb 24 '25 12:02 lcskrishna