[1/2] Refactor DeepGemm requant for FP8 Linear on Blackwell
Motivation
Co-author: @fy1214. Based on PR: https://github.com/sgl-project/sglang/pull/13067
The refactor for MoE requant is left to the next PR.
Modifications
To run DeepGemm on Blackwell, the input scale factors need to be requantized to ue8m0. The prior code only handled this requantization for the DeepSeek model classes, and it was quite messy. This PR refactors all the requant operations with the following modifications:
- Move all the requant operations for FP8 Linear layers to the `process_weights_after_loading` method. The `_executed_weight_requant_ue8m0` mark for the inverse transform scale operation is also moved there (see the sketch after this list).
- For FP8 MoE layers, the requantization is kept in `deepseek_v2.py`. In the future it should be moved to `Fp8MoEMethod.process_weights_after_loading`.
- `q_b_proj` in the dpsk-fp4-v2 checkpoint requires a weight quant (bf16 -> fp8) before requantization.
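A minimal sketch of the ue8m0 requant idea, with a hypothetical helper name and per-channel scales for brevity (the actual code handles blockwise scales inside `process_weights_after_loading`):

```python
import torch

def requant_scale_to_ue8m0(weight_fp8: torch.Tensor, scale: torch.Tensor):
    """Round scales up to powers of two and rescale the fp8 weight to match.

    ue8m0 stores only an 8-bit exponent, so every scale factor must be an
    exact power of two. Rounding *up* keeps the rescaled weight within the
    fp8 dynamic range.
    """
    # Nearest power of two that is >= scale.
    scale_ue8m0 = torch.exp2(torch.ceil(torch.log2(scale)))
    # Keep weight * scale numerically unchanged: w' = w * (s / s_ue8m0).
    weight = weight_fp8.to(torch.float32) * (scale / scale_ue8m0)
    return weight.to(torch.float8_e4m3fn), scale_ue8m0
```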
In `DefaultModelLoader.load_weights_and_postprocess`:
- First execute `model.load_weights(weights)`, where the FP8 MoE layers are requantized for the DeepSeek model.
- Then execute `quant_method.process_weights_after_loading(module)`, where the FP8 Linear layers are requantized (sketched below).
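For reference, a simplified sketch of that ordering (the real `DefaultModelLoader` does more bookkeeping; the `quant_method` attribute lookup here is an assumption for illustration):

```python
def load_weights_and_postprocess(model, weights):
    # Step 1: checkpoint loading; for DeepSeek this is still where the
    # FP8 MoE layers get requantized (to be moved in the follow-up PR).
    model.load_weights(weights)

    # Step 2: per-module postprocessing; FP8 Linear layers now do their
    # ue8m0 requant inside process_weights_after_loading.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
```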
Accuracy Tests
All cases are tested on B200 with the gsm8k dataset:
```
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
```
- DeepSeek V3/3.1 with DeepGemm fp8
```
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.1 --tp 8
```
Accuracy: 0.958
Invalid: 0.000
Latency: 28.224 s
Output throughput: 4773.001 token/s
- DeepSeek V3/3.1 with Flashinfer fp8 (should skip requantization)
```
SGLANG_ENABLE_FLASHINFER_FP8_GEMM=1 python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.1 --tp 8
```
Accuracy: 0.955
Invalid: 0.000
Latency: 37.858 s
Output throughput: 3578.284 token/s
- DeepSeek-R1 nvfp4 v1
```
python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-FP4 \
  --served-model-name dsr1 --tp 4 \
  --attention-backend trtllm_mla \
  --moe-runner-backend flashinfer_trtllm \
  --quantization modelopt_fp4 \
  --kv-cache-dtype fp8_e4m3
```
Accuracy: 0.958
Invalid: 0.000
Latency: 27.152 s
Output throughput: 5301.655 token/s
- DeepSeek-R1 nvfp4 v2 + `q_b_proj` requantized to fp8
```
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \
python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-FP4-v2 \
  --served-model-name dsr1 --tp 4 \
  --attention-backend trtllm_mla \
  --moe-runner-backend flashinfer_trtllm \
  --quantization modelopt_fp4 \
  --kv-cache-dtype fp8_e4m3
```
Accuracy: 0.959
Invalid: 0.000
Latency: 26.797 s
Output throughput: 5371.161 token/s
- Qwen3-30B-FP8
```
python3 -m sglang.launch_server \
  --model-path Qwen/Qwen3-30B-A3B-FP8 \
  --tp 2 \
  --disable-radix-cache \
  --kv-cache-dtype fp8_e4m3
```
This PR:
Accuracy: 0.696
Invalid: 0.001
Latency: 56.326 s
Output throughput: 6606.698 token/s
Main:
Accuracy: 0.000
Invalid: 1.000
Latency: 70.828 s
Output throughput: 9534.814 token/s
Benchmarking and Profiling
Checklist
- [ ] Format your code according to the Format code with pre-commit.
- [ ] Add unit tests according to the Run and add unit tests.
- [ ] Update documentation according to Write documentations.
- [ ] Provide accuracy and speed benchmark results according to Test the accuracy and Benchmark the speed.
- [ ] Follow the SGLang code style guidance.
- [ ] Work with maintainers to merge your PR. See the PR Merge Process
cc @fzyzcjy @kaixih @fy1214 Please have a look
/tag-and-rerun-ci