[WIP] Support double sparsity
Motivation
- Support double sparsity (post-training sparse attention) for long context inference in SGLang
- See paper
Modifications
- Add triton implementation in sglang/python/sglang/srt/layers/sparse_decode_attention.py (a conceptual sketch of the decode-time selection is included below)
- Add serving-related parts
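As background for the modification above, here is a minimal conceptual sketch of how double sparsity selects tokens at decode time: per-head outlier ("label") channels are gathered from the query and the cached keys, used to approximate attention scores, and only the top heavy tokens are then attended. This is an illustration only, not the actual Triton kernel; names such as heavy_channel_idx and the tensor layouts are assumptions.

```python
import torch

def select_heavy_tokens(q, k_cache, heavy_channel_idx, heavy_token_num):
    """Conceptual sketch of double-sparsity token selection at decode time.

    q:                 [num_heads, head_dim]           current decode query
    k_cache:           [seq_len, num_heads, head_dim]  cached keys
    heavy_channel_idx: [num_heads, num_label_channels] outlier channels from the
                       offline channel config (name/layout assumed for this sketch)
    """
    seq_len = k_cache.shape[0]

    # Gather only the heavy ("label") channels to form low-dimensional sketches.
    q_label = torch.gather(q, 1, heavy_channel_idx)                        # [H, C]
    k_label = torch.gather(
        k_cache, 2, heavy_channel_idx.unsqueeze(0).expand(seq_len, -1, -1)
    )                                                                      # [S, H, C]

    # Approximate attention scores with the labels and keep the heaviest tokens.
    approx_scores = torch.einsum("hc,shc->hs", q_label, k_label)           # [H, S]
    topk = approx_scores.topk(min(heavy_token_num, seq_len), dim=-1).indices

    # Full attention is then computed only over these selected tokens per head.
    return topk
```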
Checklist
- [ ] Format your code according to the Contributor Guide.
- [ ] Add unit tests as outlined in the Contributor Guide.
- [ ] Update documentation as needed, including docstrings or example tutorials.
Great work. Some tips for rebasing:
- Recently, we did some refactoring to introduce AttentionBackend. This will make the code for supporting multiple attention backends cleaner: https://github.com/sgl-project/sglang/pull/1381
- Following the above refactor, you can use CUDA graph and torch.compile to make it run up to 4x faster (https://github.com/sgl-project/sglang/pull/1401, https://github.com/sgl-project/sglang/pull/1422)
Quick question @andy-yang-1 - Does this PR support just Double Sparsity or DS-Offload as well?
@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.
Is there a plan to merge this PR?
Yes. It should be merged within one week. @andy-yang-1 please
- Resolve the conflicts.
- Add an end-to-end accuracy unit test
Please fix the lint error and add an end-to-end accuracy test
Give two example commands and paste their results in the description of this PR. This is for tracking the progress. It should be something like this:
# baseline
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8
# double sparsity
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8 --enable-double-sparsity ...
@andy-yang-1 Can you also paste the latency results?
@andy-yang-1 Thanks for the contribution. It is merged.
How does one generate the ds-channel-config to be able to use this?
I noticed that CUDA graph is not currently supported. Are there any plans to support it? @andy-yang-1
@max99x You can use this link to generate the channel config file.
@fengyang95 We may support it in the next PR
Hi @andy-yang-1, does this support the deepseek-v2 architecture? How can I obtain the config for this structure? I see that the example here https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py only supports the llama/mixtral arch.
@andy-yang-1 I tried running the deepseek-v2 model, but encountered the following issue:
File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
k_label = torch.gather(
^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA_gather)
File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/__init__.py", line 49, in forward
return self.forward_extend(q, k, v, layer, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
k_label = torch.gather(
^^^^^^^^^^^^^
RuntimeError: Size does not match at dimension 1 expected index [7, 128, 16] to be smaller than self [7, 1, 576] apart from dimension 2
@fengyang95 I haven't added support for the deepseek-v2 model. I may add support for it later.
@andy-yang-1 Thank you very much! Looking forward to support for deepseek-v2 and cuda graph.
@andy-yang-1 - Loved the paper! I was trying this out and I am facing a few issues generating the config file using the mentioned script.
- The line cos, sin = m.rotary_emb(v, seq_len=kv_seq_len) in stat_qk_max_hook of get_calib_qk_feat gives an error:
TypeError: LlamaRotaryEmbedding got an unexpected keyword argument 'seq_len'
I replaced it with cos, sin = m.rotary_emb(v, position_ids=position_ids), which works. I'm not sure if that is correct, but LlamaRotaryEmbedding indeed doesn't have the seq_len param.
- In the config file that gets generated, I only get keys of the form model.layers.{layer_num}.self_attn, but the config file present in the test folder has keys of the form model.layers.{layer_num}.self_attn.q_proj, model.layers.{layer_num}.self_attn.k_proj and model.layers.{layer_num}.self_attn.qk_proj. How were these generated? On using my generated config with sglang, I get an error of the type Key model.layers.0.self_attn.k_proj was not found.
Any help on how to run this would be appreciated.
@shreyansh26 The first problem is caused by an older version of transformers, and I will update the base repo to fix it this week. The q_outlier_config/k_outlier_config is generated with the get_calib_feat function, and the qk_outlier_config is generated with the get_qk_calib_feat function. You can merge these two configs together to get all the configs. I will also update it this week.
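For anyone following along, here is a minimal sketch of merging the two generated configs into the single channel-config JSON passed via --ds-channel-config-path. The file names are placeholders, and the exact key layout is assumed from the discussion above.

```python
import json

# Placeholder paths: use whatever get_calib_feat / get_qk_calib_feat wrote out.
with open("outlier_config.json") as f:
    proj_config = json.load(f)   # keys like model.layers.{i}.self_attn.{q,k}_proj

with open("qk_outlier_config.json") as f:
    qk_config = json.load(f)     # keys like model.layers.{i}.self_attn.qk_proj

# Merge the two dictionaries into one channel config.
merged = {**proj_config, **qk_config}

with open("ds_channel_config.json", "w") as f:
    json.dump(merged, f)
```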
Thank you.
There may be another discrepancy: in get_calib_feat, the following condition filters out k_proj because of GQA.
if y.shape[-1] != model.config.hidden_size:
    return
But in the Llama-3.1-8B-Instruct config file, k_proj keys are also present.
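For reference, one way the filter could be adjusted so GQA models keep their k_proj/v_proj statistics, assuming the intent is to also accept outputs of width num_key_value_heads * head_dim. This is a guess at a possible fix, not the change actually made in the repo.

```python
def is_attention_projection_output(y, model):
    # q_proj / o_proj outputs have width hidden_size.
    hidden_size = model.config.hidden_size
    # Under GQA, k_proj / v_proj outputs are narrower: num_key_value_heads * head_dim.
    head_dim = hidden_size // model.config.num_attention_heads
    num_kv_heads = getattr(model.config, "num_key_value_heads",
                           model.config.num_attention_heads)
    kv_size = num_kv_heads * head_dim
    return y.shape[-1] in (hidden_size, kv_size)
```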
@shreyansh26 Hi, I have updated the main repo. Can you try with this code?
Thank you @andy-yang-1!! This is working perfectly now.
> @vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.

Is there a plan to support DS-Offload in SGLang?
Speedup Evaluation
Run double sparsity with:
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
    --input-len 20000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

Benchmark ...
Prefill. latency: 7.83636 s, throughput: 7656.62 token/s
Decode. latency: 0.02351 s, throughput: 127.58 token/s
Decode. latency: 0.02124 s, throughput: 141.22 token/s
Decode. latency: 0.02037 s, throughput: 147.26 token/s
Decode. latency: 0.01950 s, throughput: 153.81 token/s
Decode. latency: 0.01935 s, throughput: 155.04 token/s
Decode. median latency: 0.01923 s, median throughput: 156.04 token/s
Total. latency: 11.821 s, throughput: 5126.36 token/s

Original triton implementation:
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 7.79627 s, throughput: 7695.98 token/s
Decode. latency: 0.07196 s, throughput: 41.69 token/s
Decode. latency: 0.06514 s, throughput: 46.05 token/s
Decode. latency: 0.06475 s, throughput: 46.33 token/s
Decode. latency: 0.06463 s, throughput: 46.41 token/s
Decode. latency: 0.06457 s, throughput: 46.46 token/s
Decode. median latency: 0.06487 s, median throughput: 46.25 token/s
Total. latency: 20.720 s, throughput: 2924.74 token/s

Original flashinfer implementation:
python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend flashinfer \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 5.68892 s, throughput: 10546.83 token/s
Decode. latency: 0.03240 s, throughput: 92.60 token/s
Decode. latency: 0.02993 s, throughput: 100.23 token/s
Decode. latency: 0.02970 s, throughput: 101.01 token/s
Decode. latency: 0.02959 s, throughput: 101.39 token/s
Decode. latency: 0.02959 s, throughput: 101.38 token/s
Decode. median latency: 0.02961 s, median throughput: 101.32 token/s
Total. latency: 11.585 s, throughput: 5231.00 token/s

With Llama-3.1-8B:
# Double Sparsity
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --ds-channel-config-path /path/to/meta-llama/Llama-3.1-8B-Instruct.json \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 32 \
    --ds-heavy-channel-type k \
    --ds-heavy-token-num 3000 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 42.96801 s, throughput: 4189.16 token/s
Decode. latency: 0.02843 s, throughput: 105.50 token/s
Decode. latency: 0.02518 s, throughput: 119.16 token/s
Decode. latency: 0.02465 s, throughput: 121.72 token/s
Decode. latency: 0.02442 s, throughput: 122.84 token/s
Decode. latency: 0.02434 s, throughput: 123.24 token/s
Decode. median latency: 0.02421 s, median throughput: 123.90 token/s
Total. latency: 47.793 s, throughput: 3778.77 token/s

# Triton
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 43.17160 s, throughput: 4169.41 token/s
Decode. latency: 0.06359 s, throughput: 47.18 token/s
Decode. latency: 0.05965 s, throughput: 50.30 token/s
Decode. latency: 0.05927 s, throughput: 50.62 token/s
Decode. latency: 0.05906 s, throughput: 50.80 token/s
Decode. latency: 0.05906 s, throughput: 50.80 token/s
Decode. median latency: 0.05913 s, median throughput: 50.73 token/s
Total. latency: 54.950 s, throughput: 3286.63 token/s

# Flashinfer
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend flashinfer \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 27.50800 s, throughput: 6543.55 token/s
Decode. latency: 0.03014 s, throughput: 99.54 token/s
Decode. latency: 0.02834 s, throughput: 105.86 token/s
Decode. latency: 0.02821 s, throughput: 106.36 token/s
Decode. latency: 0.02819 s, throughput: 106.41 token/s
Decode. latency: 0.02823 s, throughput: 106.28 token/s
Decode. median latency: 0.02821 s, median throughput: 106.34 token/s
Total. latency: 33.125 s, throughput: 5452.12 token/s
I found that prefill throughput is lower when DS attention is enabled (from 6543.55 to 4189.16 token/s). The likely reason is that triton is used as the attention backend. Is it possible to use flashinfer attention for prefill to increase prefill throughput?
Hi @andy-yang-1, this is great work! I have encountered a problem when using double sparsity in the latest version of SGLang. I followed the same command as you, but it failed. The error log is below. Could you help me fix it?
CUDA_VISIBLE_DEVICES=0 python -m sglang.bench_one_batch --model-path /home/lyy/models/Mistral-7B-v0.1 --tensor-parallel-size 1 --attention-backend triton --disable-cuda-graph --ds-channel-config-path /home/zongyi/DoubleSparse/config/mistralai/Mistral-7B-v0.1.json --input-len 20000 --output-len 200 --batch-size 1 --enable-double-sparsity --ds-heavy-channel-num 16 --ds-heavy-token-num 1024 --ds-sparse-decode-threshold 0 --max-total-tokens 70000
/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
torch_dtype is deprecated! Use dtype instead!
[2025-10-09 14:21:16 TP0] Double sparsity optimization is turned on. Use triton backend without CUDA graph.
[2025-10-09 14:21:16 TP0] Init torch distributed begin.
[rank0]:[W1009 14:21:17.340194889 ProcessGroupGloo.cpp:514] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-10-09 14:21:17 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl_moonvit: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Load weight begin. avail mem=46.94 GB
Loading safetensors checkpoint shards: 0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 33% Completed | 1/3 [00:00<00:01, 1.30it/s]
Loading safetensors checkpoint shards: 67% Completed | 2/3 [00:01<00:00, 1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.19it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.21it/s]
[2025-10-09 14:21:21 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=33.37 GB, mem usage=13.57 GB.
[2025-10-09 14:21:21 TP0] Using KV cache dtype: torch.bfloat16
[2025-10-09 14:21:21 TP0] Memory pool end. avail mem=23.76 GB
max_total_num_tokens=70000
Warmup ...
[rank0]: Traceback (most recent call last):
[rank0]: File "