sglang
[ROCm] Enable Fused MLA Triton kernel for DeepSeekV3
This PR introduces a fused MLA decode Triton kernel on ROCm. To use this feature, set the environment variable SGLANG_ROCM_FUSED_DECODE_MLA=1.
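As a rough sketch of how such an env-var gate is typically wired (the flag name matches the PR; the helper name and surrounding code are illustrative, not from the actual patch):

```python
import os

def fused_mla_enabled(env=None):
    """Return True when the ROCm fused MLA decode kernel is requested
    via SGLANG_ROCM_FUSED_DECODE_MLA=1 (hypothetical helper)."""
    if env is None:
        env = os.environ
    return env.get("SGLANG_ROCM_FUSED_DECODE_MLA", "0") == "1"

# The attention backend would then branch on this flag to pick the
# fused Triton decode path instead of the default MLA decode path.
```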
Triton Kernel authors: @juuso-oskari (Korhonen, Juuso), @Chi-Chu319 (Tianxing Wu) and @vgokhale (Gokhale Vinayak)
cc: @sunway513
@HaiShaw This is ready for review.
@lcskrishna can you share some perf uplift or comparison here? Thanks!
@saienduri can you take a look at why all the CIs are skipped and cannot be rerun?
@HaiShaw Here are some basic DeepSeek-V3 benchmarks for this feature.
Without Fusion:
% python -m sglang.bench_one_batch --batch-size 4 --input 2048 --output 256 --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
Decode. median latency: 0.03850 s, median throughput: 103.89 token/s
Total. latency: 10.656 s, throughput: 864.89 token/s
With this feature:
% SGLANG_ROCM_FUSED_DECODE_MLA=1 python -m sglang.bench_one_batch --batch-size 4 --input 2048 --output 256 --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
Decode. median latency: 0.03581 s, median throughput: 111.71 token/s
Total. latency: 9.970 s, throughput: 924.39 token/s
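For convenience, the uplift implied by the reported numbers (values copied verbatim from the benchmark output above):

```python
# Throughput ratios from the bench_one_batch runs above.
decode_speedup = 111.71 / 103.89  # decode throughput, token/s
total_speedup = 924.39 / 864.89   # end-to-end throughput, token/s

print(f"decode uplift: {decode_speedup - 1:.1%}")  # ~7.5%
print(f"total uplift:  {total_speedup - 1:.1%}")   # ~6.9%
```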