
[WIP] Support double sparsity

andy-yang-1 opened this issue 1 year ago

Motivation

  • Support double sparsity (post-training sparse attention) for long context inference in SGLang
  • See paper

Modifications

  • Add triton implementation in sglang/python/sglang/srt/layers/sparse_decode_attention.py (a rough sketch of the idea is shown after this list)
  • Add serving-related parts
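
Below is a minimal PyTorch sketch of the decode-time idea, for orientation only. Tensor layout, names, and shapes are illustrative assumptions; the actual Triton kernel interface in this PR differs.

import torch

def sparse_decode_attention_sketch(q, k_cache, v_cache, heavy_channels, heavy_token_num):
    """Illustrative double-sparsity decode step (not the actual Triton kernel).

    q:              [num_heads, head_dim]           query of the current decode token
    k_cache:        [seq_len, num_heads, head_dim]  cached keys
    v_cache:        [seq_len, num_heads, head_dim]  cached values
    heavy_channels: [num_heads, num_heavy_channels] outlier channel indices per head
    """
    seq_len, num_heads, head_dim = k_cache.shape

    # 1) Approximate attention scores using only the heavy (outlier) channels.
    q_label = torch.gather(q, 1, heavy_channels)                        # [H, C]
    k_label = torch.gather(k_cache, 2,
                           heavy_channels.expand(seq_len, -1, -1))      # [S, H, C]
    approx_scores = torch.einsum("hc,shc->hs", q_label, k_label)        # [H, S]

    # 2) Keep only the top `heavy_token_num` tokens per head.
    top_idx = approx_scores.topk(min(heavy_token_num, seq_len), dim=-1).indices

    # 3) Exact attention over the selected tokens only.
    out = torch.empty_like(q)
    scale = head_dim ** -0.5
    for h in range(num_heads):
        k_sel, v_sel = k_cache[top_idx[h], h], v_cache[top_idx[h], h]   # [k, D]
        attn = torch.softmax((k_sel @ q[h]) * scale, dim=0)             # [k]
        out[h] = attn @ v_sel
    return out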

Checklist

  • [ ] Format your code according to the Contributor Guide.
  • [ ] Add unit tests as outlined in the Contributor Guide.
  • [ ] Update documentation as needed, including docstrings or example tutorials.

andy-yang-1 avatar Sep 18 '24 22:09 andy-yang-1

Great work. Some tips for rebasing:

  • Recently, we did some refactoring to introduce AttentionBackend. This will make the code for supporting multiple attention backends cleaner: https://github.com/sgl-project/sglang/pull/1381 (a rough skeleton of how a backend plugs in is sketched after this list)
  • Following that refactor, you can use CUDA graph and torch.compile to make it run up to 4x faster (https://github.com/sgl-project/sglang/pull/1401, https://github.com/sgl-project/sglang/pull/1422)
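
A rough skeleton of what that could look like. The class name, the server-arg attribute names, and the forward_extend signature below are assumptions inferred from tracebacks quoted later in this thread, not the authoritative interface.

# Hypothetical sketch only: names and signatures are assumptions based on the
# AttentionBackend refactor (#1381) and the forward_extend(q, k, v, layer,
# forward_batch) call visible in tracebacks later in this thread.
class DoubleSparsityAttnBackend:  # would subclass the AttentionBackend base class
    def __init__(self, model_runner):
        # Heavy-channel / heavy-token budgets would come from server args such as
        # --ds-heavy-channel-num and --ds-heavy-token-num.
        self.heavy_channel_num = model_runner.server_args.ds_heavy_channel_num
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Prefill: run dense attention and cache the heavy-channel "label" slice
        # of K so that decode can approximate attention scores cheaply.
        raise NotImplementedError

    def forward_decode(self, q, k, v, layer, forward_batch):
        # Decode: estimate scores from the label cache, pick the top heavy tokens,
        # then attend over only those tokens (see the Triton kernel in this PR).
        raise NotImplementedError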

merrymercy avatar Sep 19 '24 09:09 merrymercy

Quick question @andy-yang-1 - Does this PR support just Double Sparsity or DS-Offload as well?

ghost avatar Sep 24 '24 20:09 ghost

@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.

andy-yang-1 avatar Sep 24 '24 21:09 andy-yang-1

Is there a plan to merge this PR?

fengyang95 avatar Oct 09 '24 05:10 fengyang95

Yes, it should be merged within one week. @andy-yang-1, please:

  1. Resolve the conflicts.
  2. Add an end-to-end accuracy unit test

merrymercy avatar Oct 11 '24 09:10 merrymercy

Please fix the lint error and add an end-to-end accuracy test

merrymercy avatar Oct 14 '24 02:10 merrymercy

Give two example commands and paste their results in the description of this PR. This is for tracking the progress. It should look something like this:

# baseline
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8

# double sparsity
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8 --enable-double-sparsity ...

merrymercy avatar Oct 14 '24 07:10 merrymercy

@andy-yang-1 Can you also paste the latency results?

merrymercy avatar Oct 14 '24 08:10 merrymercy

@andy-yang-1 Thanks for the contribution. It is merged.

merrymercy avatar Oct 14 '24 09:10 merrymercy

How does one generate the ds-channel-config to be able to use this?

max99x avatar Oct 14 '24 13:10 max99x

I noticed that CUDA graph is not currently supported. Are there any plans to support it? @andy-yang-1

fengyang95 avatar Oct 16 '24 15:10 fengyang95

@max99x You can use this link to generate the channel config file.
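
For reference, the channel config is a JSON dict keyed by attention projection names. A hypothetical, truncated example follows; the value format shown here (per-projection lists of heavy channel indices) is a guess for illustration only, not the exact schema.

import json

# Hypothetical shape of a channel config file; key names follow the pattern used
# by the test configs, value format is an illustrative assumption.
channel_config = {
    "model.layers.0.self_attn.q_proj": [11, 503, 1027, 2300],   # heavy channel indices
    "model.layers.0.self_attn.k_proj": [7, 89, 640, 911],
    "model.layers.0.self_attn.qk_proj": [3, 77, 512, 1999],
    # ... one entry per layer and projection
}
with open("channel_config.json", "w") as f:
    json.dump(channel_config, f, indent=2)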

@fengyang95 We may support it in the next PR

andy-yang-1 avatar Oct 16 '24 17:10 andy-yang-1

hi @andy-yang-1 Does this support the deepseek-v2 architecture? How can I obtain the config for that structure? I see that the example here https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py only supports the llama/mixtral architectures.

fengyang95 avatar Oct 18 '24 17:10 fengyang95

@andy-yang-1 I tried running the deepseek-v2 model, but encountered the following issue:

File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA_gather)
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/__init__.py", line 49, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Size does not match at dimension 1 expected index [7, 128, 16] to be smaller than self [7, 1, 576] apart from dimension 2
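
For context, the second failure comes from the shape rule torch.gather enforces: outside the gather dimension, index may not be larger than self. A tiny standalone reproduction with the same sizes (deepseek-v2 stores a single 576-wide latent KV entry per token, while the channel config index carries one row per attention head):

import torch

# Minimal reproduction of the failing gather: dim 1 is 128 (query heads in the
# index) vs 1 (latent KV heads in the cache), so gather raises the same
# "Size does not match at dimension 1" error.
k_buffer = torch.randn(7, 1, 576)                    # [tokens, kv_heads, head_dim]
index = torch.zeros(7, 128, 16, dtype=torch.long)    # [tokens, q_heads, heavy_channels]
k_label = torch.gather(k_buffer, 2, index)           # RuntimeError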

fengyang95 avatar Oct 19 '24 18:10 fengyang95

@fengyang95 I haven't added support for deepseek-v2 model. I may add support for this later

andy-yang-1 avatar Oct 19 '24 21:10 andy-yang-1

@fengyang95 I haven't added support for deepseek-v2 model. I may add support for this later

@andy-yang-1 Thank you very much! Looking forward to support for deepseek-v2 and cuda graph.

fengyang95 avatar Oct 20 '24 04:10 fengyang95

@andy-yang-1 - Loved the paper! I was trying this out and I am facing a few issues generating the config file using the mentioned script.

  1. The line cos, sin = m.rotary_emb(v, seq_len=kv_seq_len) in stat_qk_max_hook of get_calib_qk_feat gives an error
TypeError: LlamaRotaryEmbedding got an unexpected keyword argument 'seq_len'

I replaced it with cos, sin = m.rotary_emb(v, position_ids=position_ids), which works. I'm not sure if that is correct, but LlamaRotaryEmbedding indeed doesn't have the seq_len param.

  2. In the config file that gets generated, I only get keys of the form model.layers.{layer_num}.self_attn, but the config file present in the test folder has keys of the form model.layers.{layer_num}.self_attn.q_proj, model.layers.{layer_num}.self_attn.k_proj, and model.layers.{layer_num}.self_attn.qk_proj. How were these generated? When using my generated config with sglang, I am getting an error of the type - Key model.layers.0.self_attn.k_proj was not found.

Any help on how to run this would be appreciated.

shreyansh26 avatar Nov 06 '24 14:11 shreyansh26

@shreyansh26 The first problem is caused by an older version of transformers, and I will update the base repo to fix it this week. The q_outlier_config/k_outlier_config is generated with the get_calib_feat function, and the qk_outlier_config is generated with the get_qk_calib_feat function. You can merge these two configs together to get the full config. I will also update it this week.
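
A minimal sketch of that merge, assuming each script dumps its config as a JSON dict; the file names here are placeholders:

import json

# Placeholder file names: whatever get_calib_feat and get_qk_calib_feat actually save.
with open("q_k_outlier_config.json") as f:
    proj_outliers = json.load(f)    # model.layers.{i}.self_attn.q_proj / .k_proj keys
with open("qk_outlier_config.json") as f:
    qk_outliers = json.load(f)      # model.layers.{i}.self_attn.qk_proj keys

merged = {**proj_outliers, **qk_outliers}  # later keys win on collision

with open("channel_config.json", "w") as f:
    json.dump(merged, f, indent=2)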

andy-yang-1 avatar Nov 06 '24 18:11 andy-yang-1

Thank you. There may be another discrepancy: in get_calib_feat, with the following condition, k_proj gets filtered out because of GQA.

if y.shape[-1] != model.config.hidden_size:
    return

But in the Llama-3.1-8B-Instruct config file, k_proj keys are also present.
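
One possible adjustment, sketched here as an assumption rather than the repo's actual fix, is a drop-in variant of the condition above that also accepts the narrower GQA projections, whose output width is num_key_value_heads * head_dim instead of hidden_size:

# Accept both full-width projections (q_proj/o_proj) and the narrower GQA
# key/value projections instead of filtering everything != hidden_size.
head_dim = model.config.hidden_size // model.config.num_attention_heads
kv_width = model.config.num_key_value_heads * head_dim
if y.shape[-1] not in (model.config.hidden_size, kv_width):
    return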

shreyansh26 avatar Nov 07 '24 07:11 shreyansh26

@shreyansh26 Hi, I have updated the main repo. Can you try with this code?

andy-yang-1 avatar Nov 10 '24 23:11 andy-yang-1

Thank you @andy-yang-1!! This is working perfectly now.

shreyansh26 avatar Nov 11 '24 10:11 shreyansh26

@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.

Is there a plan to support DS-Offload in SGLang?

yuguo-Jack avatar Nov 23 '24 14:11 yuguo-Jack

Speedup Evaluation

Run double sparsity with:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
    --input-len 20000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

Benchmark ...
Prefill. latency: 7.83636 s, throughput:   7656.62 token/s
Decode.  latency: 0.02351 s, throughput:    127.58 token/s
Decode.  latency: 0.02124 s, throughput:    141.22 token/s
Decode.  latency: 0.02037 s, throughput:    147.26 token/s
Decode.  latency: 0.01950 s, throughput:    153.81 token/s
Decode.  latency: 0.01935 s, throughput:    155.04 token/s
Decode.  median latency: 0.01923 s, median throughput:    156.04 token/s
Total. latency: 11.821 s, throughput:   5126.36 token/s

Original triton implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 7.79627 s, throughput:   7695.98 token/s
Decode.  latency: 0.07196 s, throughput:     41.69 token/s
Decode.  latency: 0.06514 s, throughput:     46.05 token/s
Decode.  latency: 0.06475 s, throughput:     46.33 token/s
Decode.  latency: 0.06463 s, throughput:     46.41 token/s
Decode.  latency: 0.06457 s, throughput:     46.46 token/s
Decode.  median latency: 0.06487 s, median throughput:     46.25 token/s
Total. latency: 20.720 s, throughput:   2924.74 token/s

Original flashinfer implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend flashinfer \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 5.68892 s, throughput:  10546.83 token/s
Decode.  latency: 0.03240 s, throughput:     92.60 token/s
Decode.  latency: 0.02993 s, throughput:    100.23 token/s
Decode.  latency: 0.02970 s, throughput:    101.01 token/s
Decode.  latency: 0.02959 s, throughput:    101.39 token/s
Decode.  latency: 0.02959 s, throughput:    101.38 token/s
Decode.  median latency: 0.02961 s, median throughput:    101.32 token/s
Total. latency: 11.585 s, throughput:   5231.00 token/s

With Llama-3.1-8B:

# Double Sparsity
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --ds-channel-config-path /path/to/meta-llama/Llama-3.1-8B-Instruct.json \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 32 \
    --ds-heavy-channel-type k \
    --ds-heavy-token-num 3000 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 42.96801 s, throughput:   4189.16 token/s
Decode.  latency: 0.02843 s, throughput:    105.50 token/s
Decode.  latency: 0.02518 s, throughput:    119.16 token/s
Decode.  latency: 0.02465 s, throughput:    121.72 token/s
Decode.  latency: 0.02442 s, throughput:    122.84 token/s
Decode.  latency: 0.02434 s, throughput:    123.24 token/s
Decode.  median latency: 0.02421 s, median throughput:    123.90 token/s
Total. latency: 47.793 s, throughput:   3778.77 token/s

# Triton
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 43.17160 s, throughput:   4169.41 token/s
Decode.  latency: 0.06359 s, throughput:     47.18 token/s
Decode.  latency: 0.05965 s, throughput:     50.30 token/s
Decode.  latency: 0.05927 s, throughput:     50.62 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  median latency: 0.05913 s, median throughput:     50.73 token/s
Total. latency: 54.950 s, throughput:   3286.63 token/s

# Flashinfer
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend flashinfer \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 27.50800 s, throughput:   6543.55 token/s
Decode.  latency: 0.03014 s, throughput:     99.54 token/s
Decode.  latency: 0.02834 s, throughput:    105.86 token/s
Decode.  latency: 0.02821 s, throughput:    106.36 token/s
Decode.  latency: 0.02819 s, throughput:    106.41 token/s
Decode.  latency: 0.02823 s, throughput:    106.28 token/s
Decode.  median latency: 0.02821 s, median throughput:    106.34 token/s
Total. latency: 33.125 s, throughput:   5452.12 token/s

I found that prefill throughput is lower when DS attention is enabled (from 6543.55 to 4189.16 token/s). The likely reason is that the triton attention backend is used for prefill as well. Would it be possible to use flashinfer attention for prefill to increase prefill throughput?

hcyz33 avatar Jan 13 '25 06:01 hcyz33

Hi @andy-yang-1, this is great work! I have encountered a problem when using double sparsity with the latest version of SGLang. I followed the same command as yours, but it fails. The error log is below. Could you help me fix it?

CUDA_VISIBLE_DEVICES=0 python -m sglang.bench_one_batch --model-path /home/lyy/models/Mistral-7B-v0.1 \
    --tensor-parallel-size 1 --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /home/zongyi/DoubleSparse/config/mistralai/Mistral-7B-v0.1.json \
    --input-len 20000 --output-len 200 \
    --batch-size 1 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
torch_dtype is deprecated! Use dtype instead!
[2025-10-09 14:21:16 TP0] Double sparsity optimization is turned on. Use triton backend without CUDA graph.
[2025-10-09 14:21:16 TP0] Init torch distributed begin.
[rank0]:[W1009 14:21:17.340194889 ProcessGroupGloo.cpp:514] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-10-09 14:21:17 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Ignore import error when loading sglang.srt.models.kimi_vl_moonvit: cannot import name 'GELUTanh' from 'transformers.activations' (/home/zongyi/transformers/activations.py)
[2025-10-09 14:21:18 TP0] Load weight begin. avail mem=46.94 GB
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:01, 1.30it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:01<00:00, 1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.19it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00, 1.21it/s]

[2025-10-09 14:21:21 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=33.37 GB, mem usage=13.57 GB.
[2025-10-09 14:21:21 TP0] Using KV cache dtype: torch.bfloat16
[2025-10-09 14:21:21 TP0] Memory pool end. avail mem=23.76 GB max_total_num_tokens=70000
Warmup ...
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 660, in <module>
[rank0]:     main(server_args, bench_args)
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 624, in main
[rank0]:     work_func(server_args, port_args, bench_args, 0)
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 531, in latency_test
[rank0]:     latency_test_run_once(
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 433, in latency_test_run_once
[rank0]:     next_token_ids, _, batch = extend(reqs, model_runner)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/bench_one_batch.py", line 275, in extend
[rank0]:     logits_output, _ = model_runner.forward(forward_batch)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 1982, in forward
[rank0]:     output = self._forward_raw(
[rank0]:              ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 2033, in _forward_raw
[rank0]:     ret = self.forward_extend(
[rank0]:           ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 1927, in forward_extend
[rank0]:     return self.model.forward(
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 469, in forward
[rank0]:     hidden_states = self.model(
[rank0]:                     ^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 342, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:                               ^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 266, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/models/llama.py", line 197, in forward
[rank0]:     attn_output = self.attn(q, k, v, forward_batch)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 108, in forward
[rank0]:     return forward_batch.attn_backend.forward(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/attention/base_attn_backend.py", line 82, in forward
[rank0]:     return self.forward_extend(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/zy-sgl/lib/python3.11/site-packages/sglang/srt/layers/attention/double_sparsity_backend.py", line 128, in forward_extend
[rank0]:     k_label = torch.gather(
[rank0]:               ^^^^^^^^^^^^^
[rank0]: RuntimeError: Size does not match at dimension 1 expected index [20000, 32, 16] to be no larger than self [20000, 8, 128] apart from dimension 2
[rank0]:[W1009 14:21:23.273092827 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

alex1720-web avatar Oct 10 '25 01:10 alex1720-web