
[DeepSeekV3.2] Centralize NSA dispatch logic in NativeSparseAttnBackend

Open YAMY1234 opened this issue 1 month ago • 3 comments

Motivation

NativeSparseAttnBackend currently spreads dispatch logic for NSA prefill/decode implementations and MHA vs. MLA selection across multiple places:

  • Global NSA_PREFILL_IMPL / NSA_DECODE_IMPL variables that are mutated in __init__.
  • MHA vs. MLA decisions partly in deepseek_v2.handle_attention_nsa() and partly in the backend.
  • FP8-specific dequantization flags (using_mha_one_shot_fp8_dequant) set from the model side.

This makes it harder to reason about which implementation is actually used for a given batch and backend configuration, and increases the risk of inconsistent behavior between normal and CUDA graph paths.

This PR centralizes the dispatching logic in NativeSparseAttnBackend so that:

  • All strategy decisions for a batch (MHA vs. MLA, backend implementation, FP8 dequant flag) are derived in one place.
  • The model code (deepseek_v2) simply reads the decision from the backend instead of re-implementing heuristics.
  • Backend configuration becomes instance-local instead of relying on module-level mutable globals.

Modifications

Dependency

This PR is stacked on top of the FP8 MHA support work in mha_fp8 (PR #12964).
Please review and merge that PR first; once merged, the diff here will contain only the central dispatch changes.

Centralized dispatch in NativeSparseAttnBackend

  • Remove module-level globals:
    • NSA_PREFILL_IMPL: _NSA_IMPL_T
    • NSA_DECODE_IMPL: _NSA_IMPL_T
  • Introduce instance-level fields instead:
    • self.nsa_prefill_impl: _NSA_IMPL_T
    • self.nsa_decode_impl: _NSA_IMPL_T
    • self.use_mha: bool = False — backend-owned flag indicating whether MHA_ONE_SHOT is selected for this batch.
  • In __init__:
    • Initialize self.nsa_prefill_impl and self.nsa_decode_impl from model_runner.server_args.
    • Derive self.enable_auto_select_prefill_impl from self.nsa_prefill_impl == "flashmla_auto".
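A minimal sketch of the instance-local setup described above. The `server_args` attribute names (`nsa_prefill_backend`, `nsa_decode_backend`) are assumptions; the PR only states that the values come from `model_runner.server_args`:

```python
from typing import Literal

# Allowed NSA implementation names, per the PR description.
_NSA_IMPL_T = Literal[
    "flashmla_sparse", "flashmla_kv", "flashmla_auto", "fa3", "tilelang", "aiter"
]

class NativeSparseAttnBackend:
    def __init__(self, model_runner):
        server_args = model_runner.server_args
        # Instance-local fields replacing the old NSA_PREFILL_IMPL /
        # NSA_DECODE_IMPL module-level globals.
        self.nsa_prefill_impl: _NSA_IMPL_T = server_args.nsa_prefill_backend  # assumed name
        self.nsa_decode_impl: _NSA_IMPL_T = server_args.nsa_decode_backend    # assumed name
        # Backend-owned flag: whether MHA_ONE_SHOT is selected for the current batch.
        self.use_mha: bool = False
        # "flashmla_auto" defers the concrete prefill impl choice to per-batch dispatch.
        self.enable_auto_select_prefill_impl = self.nsa_prefill_impl == "flashmla_auto"
```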

Batch-level dispatch in init_forward_metadata

  • Restructure init_forward_metadata() to perform all strategy decisions up front:
    • Always call self.set_nsa_prefill_impl(forward_batch) as the central entry point for:
      • Choosing MHA vs. MLA.
      • Selecting the concrete NSA implementation (flashmla_sparse / flashmla_kv / fa3 / tilelang / aiter).
      • Updating self.use_mha and self.nsa_prefill_impl.
    • Use self.use_mha instead of a local will_use_mha variable to:
      • Drive FP8-specific dequantization logic.
      • Decide whether page table flattening is needed.
  • Update the MHA+FP8 dequantization check:
    • Compute mha_dequantize_needed as:
      • self.use_mha and forward_batch.token_to_kv_pool.dtype == torch.float8_e4m3fn.
    • Store the result in forward_batch.using_mha_one_shot_fp8_dequant so downstream kernels can rely on a consistent flag.
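Schematically, the restructured entry point looks like the sketch below; `set_nsa_prefill_impl` and `using_mha_one_shot_fp8_dequant` come from the description above, while the surrounding structure is an illustration:

```python
import torch

class NativeSparseAttnBackend:  # continuing the sketch above
    def init_forward_metadata(self, forward_batch):
        # Central entry point: chooses MHA vs. MLA and the concrete NSA impl,
        # updating self.use_mha and self.nsa_prefill_impl as side effects.
        self.set_nsa_prefill_impl(forward_batch)

        # The FP8 MHA one-shot path needs an explicit dequantization step.
        mha_dequantize_needed = (
            self.use_mha
            and forward_batch.token_to_kv_pool.dtype == torch.float8_e4m3fn
        )
        # Persist the decision so downstream kernels all see the same flag.
        forward_batch.using_mha_one_shot_fp8_dequant = mha_dequantize_needed

        if self.use_mha:
            ...  # e.g. flatten the page table for the MHA one-shot path
```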

Backend-specific impl selection

  • Replace all checks of the form if NSA_PREFILL_IMPL == "..." and if NSA_DECODE_IMPL == "..." with the instance-local equivalents if self.nsa_prefill_impl == "..." and if self.nsa_decode_impl == "...".
  • Apply this consistently in:
    • init_forward_metadata
    • init_cuda_graph_state
    • init_forward_metadata_capture_cuda_graph
    • init_forward_metadata_replay_cuda_graph
    • forward_extend
    • forward_decode
  • Update error messages to reference self.nsa_prefill_impl / self.nsa_decode_impl so logs reflect the actual runtime configuration.
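An illustrative call site showing the before/after shape of the migration (the impl string and method body are placeholders):

```python
class NativeSparseAttnBackend:  # continuing the sketch above
    def forward_decode(self, *args, **kwargs):
        # Before: `if NSA_DECODE_IMPL == "flashmla_kv": ...` read a module-level
        # global that __init__ had mutated.
        # After: instance-local state, so the normal and CUDA graph paths are
        # guaranteed to see the same configuration.
        if self.nsa_decode_impl == "flashmla_kv":
            ...  # dispatch to the FlashMLA KV kernel path (illustrative)
```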

Top-k transform method and MHA flag

  • Simplify get_topk_transform_method() so it no longer depends on the removed global NSA_PREFILL_IMPL:
    • Use self.nsa_prefill_impl == "flashmla_sparse" when deciding whether to enable TopkTransformMethod.RAGGED.
  • Ensure self.use_mha is used consistently when:
    • Deciding whether FP8 MHA requires page_table_1_flattened.
    • Determining if MLA-specific metadata should be materialized.
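A sketch of the simplified method; the enum is stubbed here, and `PAGED` is a placeholder for the non-ragged fallback, whose actual name the PR does not spell out:

```python
from enum import Enum, auto

class TopkTransformMethod(Enum):  # stub; the real enum lives in sglang's NSA code
    PAGED = auto()   # placeholder name for the non-ragged fallback
    RAGGED = auto()

class NativeSparseAttnBackend:  # continuing the sketch above
    def get_topk_transform_method(self):
        # Reads instance state instead of the removed NSA_PREFILL_IMPL global.
        if self.nsa_prefill_impl == "flashmla_sparse":
            return TopkTransformMethod.RAGGED
        return TopkTransformMethod.PAGED  # placeholder fallback
```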

Model-side changes in deepseek_v2

  • In handle_attention_nsa():
    • Update the docstring to reflect that dispatch logic is now fully centralized in NativeSparseAttnBackend.
    • Instead of calling backend.should_use_mha(forward_batch, attn):
      • Read the decision from backend.use_mha.
    • The function now returns:
      • AttnForwardMethod.MHA_ONE_SHOT if backend.use_mha is true.
      • AttnForwardMethod.MLA otherwise.
  • Remove the previous helper-based MHA decision, since the logic has moved into the backend.
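A sketch of the resulting model-side function; the enum is stubbed, and `forward_batch.attn_backend` is an assumed accessor for the active NativeSparseAttnBackend instance:

```python
from enum import Enum, auto

class AttnForwardMethod(Enum):  # stub; the real enum lives in deepseek_v2
    MHA_ONE_SHOT = auto()
    MLA = auto()

def handle_attention_nsa(attn, forward_batch):
    """Dispatch is fully centralized in NativeSparseAttnBackend; the model
    only reads the backend's per-batch decision."""
    backend = forward_batch.attn_backend  # assumed accessor
    if backend.use_mha:
        return AttnForwardMethod.MHA_ONE_SHOT
    return AttnForwardMethod.MLA
```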

Accuracy Tests

GPQA:

Repeat: 8, mean: 0.804
Scores: ['0.793', '0.823', '0.823', '0.783', '0.783', '0.803', '0.813', '0.813']
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(14037.636363636364), 'chars:std': np.float64(11045.598489854758), 'score:std': np.float64(0.3898060809385348), 'score': np.float64(0.8131313131313131)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 1411.931 s
Score: 0.813

GSM8K:

python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1319 --parallel 1319
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 36.1MB/s]
100% | 1319/1319 [00:24<00:00, 53.49it/s]
Accuracy: 0.955
Invalid: 0.000
Latency: 25.342 s
Output throughput: 5132.798 token/s

Benchmarking and Profiling

YAMY1234 avatar Nov 18 '25 22:11 YAMY1234


Please add accuracy tests results on B200 (gpqa, gsm8k 20shots)

hlu1 avatar Nov 22 '25 00:11 hlu1

Please add accuracy tests results on B200 (gpqa, gsm8k 20shots)

Updated in the PR description.

YAMY1234 avatar Nov 22 '25 23:11 YAMY1234

/tag-and-rerun-ci

Fridge003 avatar Nov 25 '25 06:11 Fridge003

All DeepSeek V3.2 tests passed: https://github.com/sgl-project/sglang/actions/runs/19591766966/job/56375469704?pr=13544

Fridge003 avatar Nov 25 '25 19:11 Fridge003