[DeepSeekV3.2] Enable pure TP & Partial DP Attention
Motivation
DeepSeekV3.2 NSA currently has rough edges when running in pure TP or partial DP attention mode (dp_size < tp_size):
- FlashMLA sparse can see an invalid `num_heads` per rank after TP sharding.
- NSA's `get_mla_metadata` crashes with TP + large bs (fixed with Fix FlashMLA Shared-Memory Overflow in SGLang's Pure-TP Mode with Low-SMEM Fallback Scheduler #2).
- The NSA fp8 MQA indexer may OOM when building a full `(num_q, num_k)` logits matrix in large pure-TP batches.
This PR makes NSA + pure TP & partial DP attention a supported and stable configuration for DeepSeekV3.2. It should be merged after Upgrade flashmla kernel for NSA tp support #13718.
Sample launch commands:
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 4 --enable-dp-attention
Modifications
- NSA backend (`nsa_backend.py`)
  - In `_forward_flashmla_sparse(...)`, pad `q`'s head dimension to the required multiple (64 on SM90, 128 on SM100+), call the sparse kernel with padded heads, then trim the output back to the original `num_heads` (TP support). See the sketch after this list.
  - Extend `topk_transform(...)` with `topk_indices_offset_override` to accept precomputed ragged offsets (indexer chunking support).
  - Cache `device_capability` / `device_sm_major` on init and reuse them in the TRTLLM ragged and FlashMLA paths.
- Server args (`server_args.py`)
  - Allow `dp_size < tp_size` for DeepSeekV3.2 NSA; log a warning for the NSA + TP configuration.
- FlashMLA cmake (`flashmla.cmake`)
  - Bump the FlashMLA submodule to `be055fb7df0090fde45f08e9cb5b8b4c0272da73` to use the latest sparse kernel (avoids crashing with large bs).
- NSA indexer (`nsa_indexer.py`)
  - Add `_should_chunk_mqa_logits(...)` to decide when to chunk fp8 MQA logits based on workload size and free GPU memory (see the sketch after this list).
  - In `_get_topk_ragged(...)`, add a chunked path that:
    - computes fp8 MQA logits in slices to avoid OOM,
    - uses `topk_indices_offset_override` so each chunk can safely reuse the global ragged offsets,
    - writes into a preallocated `(token_nums, index_topk)` buffer for all tokens.
- Fix a shape-mismatch bug in NSA sparse prefill: when MLP-sync pads tokens to TP multiples (7→8), the indexer still returns only real-token rows. Add `_pad_topk_indices(...)` to pad so that `#tokens(q) == #tokens(topk_indices)`, restoring correctness for FlashMLA-sparse under partial DP attention (e.g., TP 8, DP 4).
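As a rough illustration of the head-padding step referenced in the `nsa_backend.py` item above (a minimal sketch only; the helper name and the kernel call in the trailing comment are illustrative assumptions, not the PR's actual code):

```python
import torch

def pad_heads_for_sparse_kernel(q: torch.Tensor, sm_major: int) -> tuple[torch.Tensor, int]:
    """Pad q of shape (num_tokens, num_heads, head_dim) so num_heads is a
    multiple of 64 on SM90 or 128 on SM100+, returning the padded tensor and
    the original head count for trimming the kernel output afterwards."""
    multiple = 64 if sm_major == 9 else 128
    num_tokens, num_heads, head_dim = q.shape
    padded_heads = ((num_heads + multiple - 1) // multiple) * multiple
    if padded_heads == num_heads:
        return q, num_heads
    q_padded = q.new_zeros((num_tokens, padded_heads, head_dim))
    q_padded[:, :num_heads] = q  # extra head slots stay zero and are discarded later
    return q_padded, num_heads

# The sparse kernel then runs on q_padded, and the output is trimmed back:
#   out = flash_mla_sparse_fwd(q_padded, kv_cache, topk_indices, ...)  # hypothetical call
#   out = out[:, :num_heads]
```

The chunking decision in the `nsa_indexer.py` item can be pictured in a similar spirit (a sketch under an assumed safety margin; the real `_should_chunk_mqa_logits(...)` heuristic may use different thresholds and inputs):

```python
import torch

def should_chunk_mqa_logits(num_q: int, num_k: int, bytes_per_elem: int = 4) -> bool:
    """Chunk the fp8 MQA logits when a full (num_q, num_k) matrix would not fit
    comfortably in currently free GPU memory. The 0.5 margin is illustrative."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    logits_bytes = num_q * num_k * bytes_per_elem
    return logits_bytes > 0.5 * free_bytes
```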
Accuracy Tests
Launch with
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8
GPQA:
Repeat: 8, mean: 0.799
Scores: ['0.778', '0.798', '0.823', '0.803', '0.778', '0.803', '0.788', '0.823']
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(14523.828282828283), 'chars:std': np.float64(11247.561643307768), 'score:std': np.float64(0.38147197173296354), 'score': np.float64(0.8232323232323232)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 2267.919 s
Score: 0.823
Benchmarking and Profiling
Checklist
- [x] Format your code according to the Format code with pre-commit.
- [ ] Add unit tests according to the Run and add unit tests.
- [ ] Update documentation according to Write documentations.
- [ ] Provide accuracy and speed benchmark results according to Test the accuracy and Benchmark the speed.
- [ ] Follow the SGLang code style guidance.
- [ ] Work with maintainers to merge your PR. See the PR Merge Process
Summary of Changes
Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for pure Tensor Parallelism (TP) for the DeepSeekV3.2 model's Native Sparse Attention (NSA) mechanism. It addresses compatibility issues with the FlashMLA sparse kernel when attention heads are partitioned across devices in a TP setup by dynamically padding and unpadding query tensors. Additionally, it configures the server arguments to allow a dedicated pure TP execution mode, optimizing resource utilization for this specific model.
Highlights
- FlashMLA Sparse Kernel Compatibility: Implemented padding and unpadding logic for query tensors in `nsa_backend.py` to ensure `num_heads` aligns with FlashMLA sparse kernel requirements (multiples of 64/128 for Hopper/Blackwell GPUs), especially when Tensor Parallelism reduces the number of heads per device.
- Pure Tensor Parallelism (TP) for DeepSeek NSA: Modified `server_args.py` to enable a pure TP mode for DeepSeek NSA models. If `dp_size` is not explicitly set, it defaults to 1, allowing attention weights to be sharded across TP ranks without data parallelism.
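A rough sketch of how the server-argument handling described in the second highlight could look (a minimal sketch; the function name and structure are assumptions for illustration, not the actual `server_args.py` code):

```python
import logging

logger = logging.getLogger(__name__)

def resolve_nsa_parallel_config(tp_size: int, dp_size: int | None) -> int:
    """Default dp_size to 1 (pure TP) when unset and allow dp_size < tp_size
    for DeepSeekV3.2 NSA, warning that the configuration shards attention
    heads across TP ranks."""
    if dp_size is None:
        dp_size = 1  # pure TP: no data parallelism for attention
    if dp_size < tp_size:
        logger.warning(
            "DeepSeekV3.2 NSA with dp_size=%d < tp_size=%d (pure TP / partial DP "
            "attention): attention heads are sharded across TP ranks.",
            dp_size,
            tp_size,
        )
    return dp_size
```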
Thanks @YAMY1234~ If your PR is blocked on the FlashMLA side, you can create a new branch at https://github.com/sgl-project/FlashMLA. The flashmla kernel now integrated in sglang is built from this repo.
In _forward_flashmla_sparse(...), pad q’s head dimension to the required multiple (64 on SM90, 128 on SM100+)
@YAMY1234 For the Hxx device, padding the heads may perform poorly. Could we consider swapping the head and token dimensions via all-to-all (a2a) communication, where each rank increases the head dimension while reducing the token dimension?
Thanks for the suggestion!
For now, this padding logic only applies to the FlashMLA sparse path under pure TP / partial DP attention. The original DP attention path is unchanged; if the layout doesn't match, it will just fail as before, so there should be no behavioral change there.
I agree that doing an a2a-based head/token swap could be a good follow-up optimization for the TP / partial DP case on Hxx to reduce the padding overhead. For this PR, I think we can keep the change scoped to making pure TP & partial DP attention functional and stable, and we can explore the a2a approach in a separate perf-focused change.
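For reference, a minimal sketch of what such an a2a head/token swap could look like (purely illustrative, assuming the token dimension is replicated across TP ranks and divides evenly; the helper name and layout are assumptions, not an implementation plan):

```python
import torch
import torch.distributed as dist

def a2a_swap_tokens_for_heads(q: torch.Tensor, tp_group) -> torch.Tensor:
    """Exchange token slices for head slices: each rank ends up with all heads
    but only 1/tp_size of the tokens, instead of padding its few local heads."""
    tp_size = dist.get_world_size(tp_group)
    num_tokens, heads_per_rank, head_dim = q.shape
    assert num_tokens % tp_size == 0, "sketch assumes tokens divide evenly"
    # Slice the tokens: token slice i (with this rank's heads) is sent to rank i.
    send = q.reshape(tp_size, num_tokens // tp_size, heads_per_rank, head_dim).contiguous()
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=tp_group)
    # recv[r] now holds rank r's heads for this rank's token slice; merge the heads.
    return recv.permute(1, 0, 2, 3).reshape(
        num_tokens // tp_size, tp_size * heads_per_rank, head_dim
    )
```

A mirror all-to-all would be needed after the attention kernel to restore the original token/head layout.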
@YAMY1234 Hi, I used your branch (https://github.com/YAMY1234/sglang/tree/dpsk_tp) and got an error in PD:
[2025-11-25 11:34:35 DP0 TP2] Scheduler hit an exception: Traceback (most recent call last):
File "sglang/python/sglang/srt/managers/scheduler.py", line 2659, in run_scheduler_process
scheduler.event_loop_overlap_disagg_prefill()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/disaggregation/prefill.py", line 369, in event_loop_overlap_disagg_prefill
batch_result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/managers/scheduler.py", line 1952, in run_batch
batch_result = self.model_worker.forward_batch_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/managers/tp_worker.py", line 371, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2277, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2336, in _forward_raw
ret = self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/model_executor/model_runner.py", line 2222, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3504, in forward
hidden_states = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3315, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 3028, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1498, in forward
return self.forward_core(s)
^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1597, in forward_core
return self.forward_absorb_core(*inner_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/models/deepseek_v2.py", line 1975, in forward_absorb_core
attn_output = self.attn_mqa(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/layers/radix_attention.py", line 123, in forward
return forward_batch.attn_backend.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 101, in forward
return self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 998, in forward_extend
return self._forward_flashmla_kv(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 1261, in _forward_flashmla_kv
o, _ = flash_mla_with_kvcache(
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/flash_mla.py", line 112, in flash_mla_with_kvcache
out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: seqlens_k must have shape (batch_size)
I launched prefill as
export LWS_LEADER_ADDRESS=xxx
export LWS_GROUP_SIZE=2
export LWS_WORKER_INDEX=0
export SGLANG_DG_CACHE_DIR=/data1/models/deepgemm_cache/deepseek_v31
SGLANG_TBO_DEBUG=0 \
python3 -m sglang.launch_server \
--model-path /data1/models/DeepSeek-V3.2-Exp \
--served-model-name DeepSeek-V3.2-Exp \
--disaggregation-mode prefill \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
--dist-init-addr ${LWS_LEADER_ADDRESS}:20000 --nnodes ${LWS_GROUP_SIZE} --node-rank ${LWS_WORKER_INDEX} \
--tp 16 \
--dp 2 \
--enable-dp-attention \
--mem-fraction-static 0.80 \
--chunked-prefill-size 32768 \
--context-length 131072 \
--trust-remote-code \
--host 0.0.0.0 \
--port 12121 \
--log-level debug \
--enable-cache-report \
--enable-metrics \
--kv-cache-dtype fp8_e4m3 \
--reasoning-parser deepseek-v3 \
--tool-call-parser deepseekv31 \
--chat-template /sgl-workspace/sglang/examples/chat_template/tool_chat_template_deepseekv31.jinja \
--prefill-round-robin-balance \
--page-size 64 \
--decode-log-interval 1
Decode was launched as TP16 DP16 EP16. Can you solve this? Thank you very much.
Thanks for pointing this out!
For now this PR is mainly focused on and validated under the aggregated (agg) setup🥺. In disaggregated (disagg) mode, especially with partial-DP + disagg prefill, we haven’t fully verified all execution paths yet. We will need to add more fixes for disagg to make the partial-DP path stable.
@YAMY1234 Can you add a benchmark for bs=1? Expectedly pure TP should be faster than DP+TP
@Fridge003 Added in the PR description~
Oh I mean performance benchmark. You can test with python3 -m sglang.test.send_one
Sorry, my previously added benchmark disappeared😂; it might have been a saving error. I tested with sglang.bench_serving and it looks like TP is 4-5x faster than DP Attention during prefill at bs=1. Could you take a second look at the PR description? Thanks!
@YAMY1234 Thanks~
Since this PR will break the existing usage of DeepSeek V3.2, can you please update all the related usages (appending the --dp argument) in the test cases (files with the prefix test_deepseek_v32)?
A pure TP test can be added under the test/nightly folder. The documentation also needs an update.
@Fridge003 Thanks! Added docs and unittest~
/tag-and-rerun-ci