Support MLA in Torch Native Attention Backend
Motivation
Modifications
Checklist
- [ ] Format your code according to the Code Formatting with Pre-Commit.
- [ ] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
Hi @ispobock, could you help review this?
Could you fix the PR test and provide some benchmark data vs. the previous version?
Hi @ispobock, the failed test seems to be unrelated to this PR's changes. Is there any way to re-trigger the failed test to avoid the flaky error?
FAIL: test_gsm8k (__main__.TestW8A8)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/public_sglang_ci/runner-b-gpu-67/_work/sglang/sglang/test/srt/test_w8a8_quantization.py", line 45, in test_gsm8k
self.assertGreater(metrics["accuracy"], 0.7)
AssertionError: 0.68 not greater than 0.7
Here is the performance comparison using TestTorchNativeAttnBackend::test_latency.
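For reference, the numbers below were collected with a command roughly like the following (the exact test file path is an assumption):

```bash
# Run only the latency test of the torch-native attention backend suite.
# The test file path is assumed; adjust it to wherever TestTorchNativeAttnBackend lives.
python3 -m pytest test/srt/test_torch_native_attention_backend.py::TestTorchNativeAttnBackend::test_latency -s
```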
- before change:
Warmup ...
Prefill. latency: 0.09715 s, throughput: 1317.49 token/s
Decode. latency: 0.02727 s, throughput: 36.67 token/s
Decode. latency: 0.02193 s, throughput: 45.60 token/s
Decode. latency: 0.02186 s, throughput: 45.74 token/s
Decode. latency: 0.02186 s, throughput: 45.74 token/s
Decode. latency: 0.02188 s, throughput: 45.70 token/s
Decode. median latency: 0.02188 s, median throughput: 45.70 token/s
Total. latency: 0.256 s, throughput: 531.85 token/s
Benchmark ...
Prefill. latency: 0.02665 s, throughput: 4803.48 token/s
Decode. latency: 0.02181 s, throughput: 45.86 token/s
Decode. latency: 0.02164 s, throughput: 46.22 token/s
Decode. latency: 0.02168 s, throughput: 46.12 token/s
Decode. latency: 0.02165 s, throughput: 46.18 token/s
Decode. latency: 0.02164 s, throughput: 46.22 token/s
Decode. median latency: 0.02168 s, median throughput: 46.12 token/s
Total. latency: 0.179 s, throughput: 761.47 token/s
- after change:
Warmup ...
Prefill. latency: 0.09651 s, throughput: 1326.30 token/s
Decode. latency: 0.90841 s, throughput: 1.10 token/s
Decode. latency: 0.02256 s, throughput: 44.32 token/s
Decode. latency: 0.02215 s, throughput: 45.14 token/s
Decode. latency: 0.02197 s, throughput: 45.52 token/s
Decode. latency: 0.02222 s, throughput: 45.00 token/s
Decode. median latency: 0.02222 s, median throughput: 45.00 token/s
Total. latency: 1.138 s, throughput: 119.49 token/s
Benchmark ...
Prefill. latency: 0.02040 s, throughput: 6275.52 token/s
Decode. latency: 0.02217 s, throughput: 45.11 token/s
Decode. latency: 0.02213 s, throughput: 45.19 token/s
Decode. latency: 0.02201 s, throughput: 45.43 token/s
Decode. latency: 0.02194 s, throughput: 45.57 token/s
Decode. latency: 0.02206 s, throughput: 45.34 token/s
Decode. median latency: 0.02206 s, median throughput: 45.34 token/s
Total. latency: 0.175 s, throughput: 778.87 token/s
It seems that the decode perf is impacted; I will investigate it.
Compared to the latest main branch, the decode perf shows no obvious gap and the prefill perf is improved.
- main branch:
Warmup ...
Prefill. latency: 0.10436 s, throughput: 1226.48 token/s
Decode. latency: 0.92203 s, throughput: 1.08 token/s
Decode. latency: 0.02368 s, throughput: 42.22 token/s
Decode. latency: 0.02328 s, throughput: 42.96 token/s
Decode. latency: 0.02306 s, throughput: 43.37 token/s
Decode. latency: 0.02324 s, throughput: 43.02 token/s
Decode. median latency: 0.02328 s, median throughput: 42.96 token/s
Total. latency: 1.166 s, throughput: 116.63 token/s
Benchmark ...
Prefill. latency: 0.02907 s, throughput: 4403.07 token/s
Decode. latency: 0.02334 s, throughput: 42.84 token/s
Decode. latency: 0.02322 s, throughput: 43.06 token/s
Decode. latency: 0.02301 s, throughput: 43.47 token/s
Decode. latency: 0.02306 s, throughput: 43.36 token/s
Decode. latency: 0.02321 s, throughput: 43.09 token/s
Decode. median latency: 0.02320 s, median throughput: 43.10 token/s
Total. latency: 0.191 s, throughput: 711.22 token/s
- this PR:
Warmup ...
Prefill. latency: 0.11434 s, throughput: 1119.49 token/s
Decode. latency: 0.91263 s, throughput: 1.10 token/s
Decode. latency: 0.02263 s, throughput: 44.18 token/s
Decode. latency: 0.02238 s, throughput: 44.69 token/s
Decode. latency: 0.02231 s, throughput: 44.81 token/s
Decode. latency: 0.02233 s, throughput: 44.78 token/s
Decode. median latency: 0.02238 s, median throughput: 44.69 token/s
Total. latency: 1.161 s, throughput: 117.10 token/s
Benchmark ...
Prefill. latency: 0.02052 s, throughput: 6238.77 token/s
Decode. latency: 0.02227 s, throughput: 44.91 token/s
Decode. latency: 0.02228 s, throughput: 44.88 token/s
Decode. latency: 0.02226 s, throughput: 44.92 token/s
Decode. latency: 0.02230 s, throughput: 44.84 token/s
Decode. latency: 0.02238 s, throughput: 44.68 token/s
Decode. median latency: 0.02227 s, median throughput: 44.90 token/s
Total. latency: 0.177 s, throughput: 770.39 token/s
@ispobock
Hi @YangQun1, I reviewed this PR, but I'm not sure why this change is related to MLA?
With this PR, we can run the DeepSeek-V2-Lite model with the torch native backend without setting the --disable-mla flag.
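For example, a launch would look roughly like this (the model path and --trust-remote-code are illustrative assumptions; only --attention-backend torch_native and the dropped --disable-mla flag come from this PR discussion):

```bash
# Sketch of a launch without --disable-mla; adjust the model path and flags to your setup.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V2-Lite \
  --trust-remote-code \
  --attention-backend torch_native
```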
Got it. This change is mainly for the forward_normal part, where the kv used for attention is different from what is stored in the KV cache.
Hi @ispobock, the CI tests passed, could you help merge?
cc @zhyncs, please help merge.
Hi @zhyncs, could you help take a look at the CI failure? It looks like a repo access issue.
raise FileNotFoundError(msg) from err
FileNotFoundError: sgl-project/sglang-ci-dsv3-test (repository not found)
Hi @YangQun1, is there any update on this MLA support for the torch_native attention backend? Could you please also take CPU into account? Thanks!
cc @mingfeima @chunyuan-w
Hi @zhyncs, could you please also take a look of this PR? Thanks!
@YangQun1 Thanks~
I will also test more models on local machines for torch_native backend.
Fixed known issues for the CPU device. Thanks @yanbing-j for the help.
Hi @zhyncs, could you help trigger the CI?
Hi @zhyncs ,
With --device cpu --attention-backend torch_native, V2-lite can run now. R1 FP8 and R1 INT8 can also run with https://github.com/sgl-project/sglang/pull/7820.
For now, the CI only covers V2-lite for CPU, so this PR is okay to merge. We are expanding the model scope to include R1 as well.
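As a concrete example, a CPU launch of V2-lite would look roughly like this (only --device cpu and --attention-backend torch_native come from the comment above; the rest is an illustrative assumption):

```bash
# Sketch of a CPU run with the torch_native backend; adjust the model path and flags to your setup.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V2-Lite \
  --trust-remote-code \
  --device cpu \
  --attention-backend torch_native
```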
@YangQun1 Please go through all the CI failures and fix any that are related to this PR.
re-trigger CI.
@YangQun1 @yanbing-j Can we take a look at the CI failures? They seem to be reproducible.
Hi @zhyncs,
I checked https://github.com/sgl-project/sglang/actions/runs/16434788523/job/46443992313?pr=3475, which fails on MI325 and passes on MI300. I tried to test on a local MI300 with the command line in https://docs.sglang.ai/start/install.html#method-2-from-source, and the UT indeed passes on MI300. Could you please give some advice about this failure on MI325? Thanks!
The other CI failures are not related to the code changes in this PR.
Hi @zhyncs,
I find that https://github.com/sgl-project/sglang/actions/runs/16434788523/job/46443992313?pr=3475 is the same as https://github.com/sgl-project/sglang/actions/runs/16465853506/job/46604886976?pr=8212. Is this PR okay to merge? Could you please take a look? Thanks!
Any idea why the AMD CI jobs failed? Is this related to this PR?
The AMD CI failures are not caused by this PR; the same failures can be observed in other PRs' CI.
I did the rebase again and found that https://github.com/sgl-project/sglang/pull/8416 breaks the V2-lite CPU CI due to the transformers version bump from 4.53.2 to 4.54.0. After reverting the version back to 4.53.2, the CI can pass.
@zhyncs Could you please take a look? Thanks!
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/yanbingj/project/sglang/python/sglang/bench_one_batch.py", line 59, in <module>
from sglang.srt.configs.model_config import ModelConfig
File "/home/yanbingj/project/sglang/python/sglang/srt/configs/model_config.py", line 31, in <module>
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
File "/home/yanbingj/project/sglang/python/sglang/srt/layers/quantization/__init__.py", line 13, in <module>
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py", line 16, in <module>
from vllm.model_executor.layers.fused_moe import (
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/__init__.py", line 7, in <module>
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/config.py", line 12, in <module>
from vllm.config import ParallelConfig
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/config.py", line 37, in <module>
from vllm.transformers_utils.config import (
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/transformers_utils/config.py", line 33, in <module>
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/transformers_utils/configs/__init__.py", line 26, in <module>
from vllm.transformers_utils.configs.ovis import OvisConfig
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/transformers_utils/configs/ovis.py", line 76, in <module>
AutoConfig.register("aimv2", AIMv2Config)
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py", line 1306, in register
CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
File "/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py", line 993, in register
raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
ValueError: 'aimv2' is already used by a Transformers config, pick another name.
The above failure is because all vLLM <= v0.10.0 with transformers >= 4.54.0 will encounter this issue; vLLM fixed it in vllm-project/vllm@3fc9644.
Upgrading vllm to v0.10.0 or removing the vllm dependency makes the CPU CI pass.
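For a local environment, a workaround along these lines should avoid the conflict (the version pins are assumptions taken from the discussion above):

```bash
# Option 1: move to a vLLM release that contains the fix from vllm-project/vllm@3fc9644.
pip install --upgrade "vllm>=0.10.0"

# Option 2: pin transformers back so 'aimv2' is not already registered by transformers itself.
pip install "transformers==4.53.2"
```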
The CI failure in https://github.com/sgl-project/sglang/actions/runs/16562758648/job/46836306570?pr=3475 is the same as https://github.com/sgl-project/sglang/actions/runs/16322800115/job/46104852281, which should be fixed by https://github.com/sgl-project/sglang/pull/8461.
@Alcanderian I checked the CI failures, and they are not caused by this PR. Could you please help merge this PR? Thanks!
This failure is definitely caused by this PR because the interface is changed: https://github.com/sgl-project/sglang/actions/runs/16819663128/job/47644040981?pr=3475#step:5:736
@YangQun1 Could you please take a look at the CI failure?