[DeepSeekV3.2] Enable pure TP & Partial DP Attention

Open: YAMY1234 opened this pull request

Motivation

DeepSeekV3.2 NSA currently has rough edges when running in pure TP mode or with partial DP attention (dp_size < tp_size).

This PR makes NSA with pure TP and partial DP attention a supported, stable configuration for DeepSeekV3.2. It should be merged after "Upgrade flashmla kernel for NSA tp support" (#13718).

Sample launch commands:

python -m sglang.launch_server  --model deepseek-ai/DeepSeek-V3.2-Exp  --tp 8
python -m sglang.launch_server  --model deepseek-ai/DeepSeek-V3.2-Exp  --tp 8 --dp 4 --enable-dp-attention
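The --tp and --dp flags above decide how many attention heads each rank computes, which is what the FlashMLA sparse path cares about. The sketch below is a rough illustration under two assumptions: DeepSeek-V3.2 keeps DeepSeek-V3's 128 attention heads, and with --enable-dp-attention each DP group shards attention over tp // dp ranks. `attn_layout` and `AttnLayout` are hypothetical names, not sglang APIs.

```python
from dataclasses import dataclass

# Assumption: DeepSeek-V3.x uses 128 attention heads (num_attention_heads in config.json).
NUM_HEADS = 128


@dataclass
class AttnLayout:
    mode: str            # "pure TP", "partial DP attention" or "full DP attention"
    attn_tp_size: int    # ranks sharing one attention replica
    heads_per_rank: int  # attention heads each rank computes


def attn_layout(tp: int, dp: int = 1, num_heads: int = NUM_HEADS) -> AttnLayout:
    """Hypothetical helper (not an sglang API): per-rank attention split."""
    if dp < 1 or tp % dp != 0:
        raise ValueError("dp must be a positive divisor of tp")
    attn_tp = tp // dp
    if num_heads % attn_tp != 0:
        raise ValueError("num_heads must be divisible by tp // dp")
    if dp == 1:
        mode = "pure TP"
    elif dp < tp:
        mode = "partial DP attention"
    else:
        mode = "full DP attention"
    return AttnLayout(mode, attn_tp, num_heads // attn_tp)


for tp, dp in [(8, 1), (8, 4), (8, 8)]:
    print(f"--tp {tp} --dp {dp}:", attn_layout(tp, dp))
```

Under these assumptions, --tp 8 leaves 16 heads per rank and --tp 8 --dp 4 leaves 64, while full DP attention keeps all 128. That is why the first layout needs head padding on both SM90 (multiple of 64) and SM100+ (multiple of 128), and the second only on SM100+.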

Modifications

  • NSA backend (nsa_backend.py)

    • In _forward_flashmla_sparse(...), pad q’s head dimension to the required multiple (64 on SM90, 128 on SM100+), call the sparse kernel with the padded heads, then trim the output back to the original num_heads (TP support).
    • Extend topk_transform(...) with topk_indices_offset_override to accept precomputed ragged offsets (indexer chunking support).
    • Cache device_capability / device_sm_major on init and reuse them in the TRTLLM ragged and FlashMLA paths.
  • Server args (server_args.py)

    • Allow dp_size < tp_size for DeepSeekV3.2 NSA:
      • Log a warning for NSA + TP configuration.
  • FlashMLA cmake (flashmla.cmake)

    • Bump the FlashMLA submodule to be055fb7df0090fde45f08e9cb5b8b4c0272da73 to pick up the latest sparse kernel (avoids crashes with large batch sizes).
  • NSA indexer (nsa_indexer.py)

    • Add _should_chunk_mqa_logits(...) to decide when to chunk fp8 MQA logits based on workload size and free GPU memory.
    • In _get_topk_ragged(...), add a chunked path that:
      • Computes fp8 MQA logits in slices to avoid OOM.
      • Uses topk_indices_offset_override so each chunk can reuse the global ragged offsets safely.
      • Writes into a preallocated (token_nums, index_topk) buffer for all tokens.
  • Fix a shape-mismatch bug in NSA sparse prefill:
    when MLP-sync pads tokens to a TP multiple (e.g., 7→8), the indexer still returns rows only for the real tokens.
    _pad_topk_indices(...) pads topk_indices so that #tokens(q) == #tokens(topk_indices), restoring correctness for FlashMLA-sparse under partial DP attention (e.g., TP 8, DP 4).
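The indexer changes above can be sketched in plain Python. Everything here is hypothetical: should_chunk, chunked_topk and pad_topk_indices are illustrative stand-ins for _should_chunk_mqa_logits, the chunked path in _get_topk_ragged and _pad_topk_indices, not their actual code. The memory heuristic, the 4-byte logit size and the use of -1 as the invalid index are assumptions. The sketch shows three properties the description relies on: logits are materialized one slice of tokens at a time, each slice applies the global ragged offsets of its own rows (so chunking does not change the result), and the output is padded with invalid rows until it has as many rows as the padded q.

```python
from typing import Callable, List

# Hypothetical sketch; names and heuristics are illustrative, not sglang's code.
Row = List[float]


def should_chunk(num_tokens: int, num_kv: int, free_bytes: int,
                 bytes_per_logit: int = 4, budget: float = 0.5) -> bool:
    # Assumed heuristic: chunk when the full (num_tokens, num_kv) logits
    # matrix would need more than `budget` of the currently free memory.
    return num_tokens * num_kv * bytes_per_logit > budget * free_bytes


def local_topk(row: Row, k: int) -> List[int]:
    # Indices of the k highest scores within one request's KV, padded with -1.
    order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
    return order + [-1] * (k - len(order))


def chunked_topk(logits_fn: Callable[[int, int], List[Row]],
                 offsets: List[int], k: int, chunk_size: int) -> List[List[int]]:
    # offsets[t]: global start of token t's request in the flattened ragged KV.
    num_tokens = len(offsets)
    out = [[-1] * k for _ in range(num_tokens)]  # preallocated (num_tokens, k)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # Only rows [start, end) of the logits are materialized at once.
        for i, row in enumerate(logits_fn(start, end)):
            off = offsets[start + i]  # global offset, not relative to the chunk
            out[start + i] = [j + off if j >= 0 else -1 for j in local_topk(row, k)]
    return out


def pad_topk_indices(topk: List[List[int]], padded_tokens: int, k: int) -> List[List[int]]:
    # Append all-invalid rows so the row count matches the padded q.
    return topk + [[-1] * k for _ in range(padded_tokens - len(topk))]


scores = [[0.1, 0.9, 0.5], [0.3, 0.2, 0.8, 0.7], [0.6, 0.1, 0.4, 0.9]]
offsets = [0, 3, 3]  # token 0 -> request A (KV 0-2); tokens 1, 2 -> request B (KV 3-6)
topk = chunked_topk(lambda s, e: scores[s:e], offsets, k=2, chunk_size=2)
print(pad_topk_indices(topk, padded_tokens=4, k=2))  # [[1, 2], [5, 6], [6, 3], [-1, -1]]
```

Running the same inputs with chunk_size=1 or a chunk larger than the batch gives identical indices, which is the property the offset override is meant to preserve.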

Accuracy Tests

Launch with

python -m sglang.launch_server  --model deepseek-ai/DeepSeek-V3.2-Exp  --tp 8
GPQA:
Repeat: 8, mean: 0.799
Scores: ['0.778', '0.798', '0.823', '0.803', '0.778', '0.803', '0.788', '0.823']
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(14523.828282828283), 'chars:std': np.float64(11247.561643307768), 'score:std': np.float64(0.38147197173296354), 'score': np.float64(0.8232323232323232)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 2267.919 s
Score: 0.823
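As a quick consistency check on the log above: the progress counter shows 198 GPQA questions, so each repeat score corresponds to a whole number of correct answers, and the reported mean of 0.799 is the average over the eight repeats. The final Score: 0.823 line equals 163/198, so it appears to report a single repeat rather than the 8-repeat mean.

```python
from statistics import mean

NUM_QUESTIONS = 198  # from the "/198" progress counter in the log
scores = [0.778, 0.798, 0.823, 0.803, 0.778, 0.803, 0.788, 0.823]

correct = [round(s * NUM_QUESTIONS) for s in scores]  # correct answers per repeat
print(correct)
print(f"{mean(scores):.3f}")                   # mean over the 8 repeats
print(f"{mean(correct) / NUM_QUESTIONS:.3f}")  # same mean from exact counts
print(f"{163 / NUM_QUESTIONS:.16f}")           # matches the final 'score' field
```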

Benchmarking and Profiling

Checklist

YAMY1234, Nov 20 '25 10:11

Summary of Changes

This pull request introduces support for pure Tensor Parallelism (TP) for the DeepSeekV3.2 model's Native Sparse Attention (NSA) mechanism. It addresses compatibility issues with the FlashMLA sparse kernel when num_heads are partitioned across devices in a TP setup by dynamically padding and unpadding query tensors. Additionally, it configures the server arguments to allow a dedicated pure TP execution mode for this model.

Highlights

  • FlashMLA Sparse Kernel Compatibility: Implemented padding and unpadding logic for query tensors in nsa_backend.py to ensure num_heads aligns with FlashMLA sparse kernel requirements (multiples of 64/128 for Hopper/Blackwell GPUs), especially when Tensor Parallelism reduces the number of heads per device.
  • Pure Tensor Parallelism (TP) for DeepSeek NSA: Modified server_args.py to enable a pure TP mode for DeepSeek NSA models. If dp_size is not explicitly set, it defaults to 1, allowing attention weights to be sharded across TP ranks without data parallelism.

gemini-code-assist[bot], Nov 20 '25 10:11
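The pad-then-trim step described above can be illustrated with nested lists standing in for tensors. This is a hypothetical sketch, not the code in nsa_backend.py: `toy_kernel` stands in for the FlashMLA sparse kernel, and the multiples (64 on SM90, 128 on SM100+) come from the PR description. It relies on the fact that in MQA-style attention each head's output depends only on its own query, so the zero-filled dummy heads can be dropped without affecting the real ones.

```python
from typing import Callable, List

# Hypothetical sketch; not the actual nsa_backend.py implementation.
Tensor3 = List[List[List[float]]]  # [tokens][heads][head_dim]


def head_multiple(sm_major: int) -> int:
    # Multiples taken from the PR description: 64 on SM90, 128 on SM100+.
    if sm_major < 9:
        raise ValueError("only SM90 and SM100+ are modeled in this sketch")
    return 64 if sm_major == 9 else 128


def round_up(n: int, m: int) -> int:
    return -(-n // m) * m


def forward_with_head_padding(q: Tensor3, sm_major: int,
                              kernel: Callable[[Tensor3], Tensor3]) -> Tensor3:
    num_heads, head_dim = len(q[0]), len(q[0][0])
    padded = round_up(num_heads, head_multiple(sm_major))
    # Append zero-filled dummy heads to every token.
    q_pad = [tok + [[0.0] * head_dim for _ in range(padded - num_heads)] for tok in q]
    out = kernel(q_pad)
    # Each head is computed independently, so trimming the dummy heads
    # leaves the real heads' outputs unchanged.
    return [tok[:num_heads] for tok in out]


seen: List[int] = []


def toy_kernel(x: Tensor3) -> Tensor3:
    # Stand-in for the sparse kernel: records the head count it receives
    # and applies a per-head operation.
    seen.append(len(x[0]))
    return [[[2.0 * v for v in h] for h in tok] for tok in x]


q = [[[float(t + h), 1.0] for h in range(16)] for t in range(3)]  # 3 tokens, 16 heads
out = forward_with_head_padding(q, sm_major=9, kernel=toy_kernel)
print(seen, len(out[0]))  # [64] 16
```

The cost of this approach is that the kernel does up to 4x (SM90) or 8x (SM100+) the head work for 16 real heads, which is the overhead discussed in the review below.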

Thanks @YAMY1234~ If your PR is blocked on the FlashMLA side, you can create a new branch at https://github.com/sgl-project/FlashMLA. The FlashMLA kernels currently integrated in sglang are built from this repo.

Fridge003, Nov 20 '25 18:11

In _forward_flashmla_sparse(...), pad q’s head dimension to the required multiple (64 on SM90, 128 on SM100+)

@YAMY1234 On Hxx devices, padding the heads may hurt performance. Could we consider swapping the head and token dimensions via all-to-all (a2a) communication, so that each rank holds more heads but fewer tokens?

xu-yfei, Nov 25 '25 08:11
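The all-to-all idea suggested above can be illustrated with a single-process simulation. This is a sketch of the reviewer's proposal, not code from this PR. It assumes heads are sharded contiguously by rank and that the token count divides evenly across ranks (a real version would need token padding otherwise, and would use a collective such as torch.distributed.all_to_all_single instead of Python lists). Each rank goes from holding all T tokens with H/P heads to holding T/P tokens with all H heads; with P = 8 and 128 heads, a rank would hold 128 heads instead of 16, so no head padding would be needed.

```python
from typing import List

# Hypothetical simulation of the suggested a2a layout swap (not in this PR).
# shards[r][t] = values of the heads owned by rank r for token t.
Shards = List[List[List[int]]]


def heads_to_tokens(shards: Shards) -> Shards:
    """Simulated all-to-all: [P][T][H/P] -> [P][T/P][H]."""
    p, t = len(shards), len(shards[0])
    if t % p != 0:
        raise ValueError("token count must be divisible by the number of ranks")
    per_rank = t // p
    out = []
    for dst in range(p):
        rows = []
        for tok in range(dst * per_rank, (dst + 1) * per_rank):
            row = []
            for src in range(p):  # gather every rank's heads, in rank order
                row.extend(shards[src][tok])
            rows.append(row)
        out.append(rows)
    return out


def tokens_to_heads(chunks: Shards) -> Shards:
    """Inverse exchange for the attention output: [P][T/P][H] -> [P][T][H/P]."""
    p, h = len(chunks), len(chunks[0][0])
    hpr = h // p
    return [[row[r * hpr:(r + 1) * hpr] for chunk in chunks for row in chunk]
            for r in range(p)]


# 2 ranks, 4 tokens, 4 heads; value 10 * token + head makes the layout visible.
shards = [[[10 * t + h for h in range(2 * r, 2 * r + 2)] for t in range(4)] for r in range(2)]
moved = heads_to_tokens(shards)
print(moved)
print(tokens_to_heads(moved) == shards)  # round trip restores the TP layout
```

The trade-off is two extra collectives per attention layer (one for q, one for the output) in exchange for running the kernel without dummy heads.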

Thanks for the suggestion!

For now, this padding logic only applies to the FlashMLA sparse path under pure TP / partial DP attention. The original (normal) DP attention path is unchanged: if the layout doesn’t match, it fails as before, so there is no behavioral change there.

I agree that doing an a2a-based head/token swap could be a good follow-up optimization for the TP / partial DP case on Hxx to reduce the padding overhead. For this PR, I think we can keep the change scoped to making pure TP & partial DP attention functional and stable, and we can explore the a2a approach in a separate perf-focused change.

YAMY1234, Nov 25 '25 10:11

@YAMY1234 Hi, I'm using your branch (https://github.com/YAMY1234/sglang/tree/dpsk_tp) and hit an error in PD disaggregation:

[2025-11-25 11:34:35 DP0 TP2] Scheduler hit an exception: Traceback (most recent call last):
  File "sglang/python/sglang/srt/managers/scheduler.py", line 2659, in run_scheduler_process
    scheduler.event_loop_overlap_disagg_prefill()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/disaggregation/prefill.py", line 369, in event_loop_overlap_disagg_prefill
    batch_result = self.run_batch(batch)
                   ^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/managers/scheduler.py", line 1952, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/managers/tp_worker.py", line 371, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2277, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2336, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2222, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3504, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3315, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3028, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1498, in forward
    return self.forward_core(s)
           ^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1597, in forward_core
    return self.forward_absorb_core(*inner_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1975, in forward_absorb_core
    attn_output = self.attn_mqa(
                  ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/layers/radix_attention.py", line 123, in forward
    return forward_batch.attn_backend.forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 101, in forward
    return self.forward_extend(
           ^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 998, in forward_extend
    return self._forward_flashmla_kv(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 1261, in _forward_flashmla_kv
    o, _ = flash_mla_with_kvcache(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/flash_mla.py", line 112, in flash_mla_with_kvcache
    out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: seqlens_k must have shape (batch_size)

I launched prefill with:

export LWS_LEADER_ADDRESS=xxx
export LWS_GROUP_SIZE=2
export LWS_WORKER_INDEX=0

export SGLANG_DG_CACHE_DIR=/data1/models/deepgemm_cache/deepseek_v31

SGLANG_TBO_DEBUG=0 \
python3 -m sglang.launch_server \
--model-path /data1/models/DeepSeek-V3.2-Exp \
--served-model-name DeepSeek-V3.2-Exp \
--disaggregation-mode prefill \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
--dist-init-addr ${LWS_LEADER_ADDRESS}:20000 --nnodes ${LWS_GROUP_SIZE} --node-rank ${LWS_WORKER_INDEX} \
--tp 16 \
--dp 2 \
--enable-dp-attention \
--mem-fraction-static 0.80 \
--chunked-prefill-size 32768 \
--context-length 131072 \
--trust-remote-code \
--host 0.0.0.0 \
--port 12121 \
--log-level debug \
--enable-cache-report \
--enable-metrics \
--kv-cache-dtype fp8_e4m3 \
--reasoning-parser deepseek-v3 \
--tool-call-parser deepseekv31 \
--chat-template /sgl-workspace/sglang/examples/chat_template/tool_chat_template_deepseekv31.jinja \
--prefill-round-robin-balance \
--page-size 64 \
--decode-log-interval 1

Decode was launched as TP16 DP16 EP16. Can you solve this? Thank you very much.

llc-kc, Nov 25 '25 11:11

Thanks for pointing this out!

For now this PR is mainly focused on and validated under the aggregated (agg) setup🥺. In disaggregated (disagg) mode, especially with partial-DP + disagg prefill, we haven’t fully verified all execution paths yet. We will need to add more fixes for disagg to make the partial-DP path stable.

YAMY1234, Nov 25 '25 18:11

@YAMY1234 Can you add a benchmark for bs=1? Pure TP is expected to be faster than DP+TP.

Fridge003, Nov 27 '25 15:11

@Fridge003 Added in the PR description~

YAMY1234, Nov 28 '25 18:11

Oh, I mean a performance benchmark. You can test with python3 -m sglang.test.send_one

Fridge003, Nov 29 '25 22:11

Sorry, the benchmark I added earlier disappeared😂, maybe a saving error. I tested with sglang.bench_serving, and it looks like TP is 4-5x faster than DP attention during prefill at bs=1. Could you take another look at the PR description? Thanks!

YAMY1234, Nov 30 '25 19:11

@YAMY1234 Thanks~ Since this PR will break existing DeepSeek V3.2 usage, can you please update all related usages in the test cases (files prefixed with test_deepseek_v32) by appending the --dp argument? A pure TP test can be added under the test/nightly folder. The documentation also needs an update.

Fridge003, Nov 30 '25 19:11

@Fridge003 Thanks! Added docs and unittest~

YAMY1234, Nov 30 '25 20:11

/tag-and-rerun-ci

Fridge003, Nov 30 '25 21:11