[Draft] Torch DeepSeek-V2
- MLA implementation hinted by https://kexue.fm/archives/10091 ; k and v share the same cache blocks (a toy sketch of the idea follows this list).
- Support shared kv in paged attention to reduce smem usage.
- `q_a_proj`, `kv_a_proj_with_mqa` in the attention layer and `gate` in the MoE layer are not distributed, so fewer NCCL ops are required at the cost of extra memory.
- Each GPU is responsible for 20 experts (8x A100).
- `block_size=32` gives better (more than double) performance than the default `block_size=64`.
- A large amount of runtime memory is required; a big `cache_max_entry_count` and `max_prefill_token_num` might lead to OOM.
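To make the "kv share the same cache blocks" point concrete, here is a minimal, hedged sketch of the idea (not the lmdeploy kernels): with MLA only the compressed kv latent is written to the paged cache, and both k and v are up-projected from that single block pool at attention time. All sizes and helper names below are illustrative, and the decoupled RoPE key part is omitted for brevity.

```python
import torch

BLOCK_SIZE = 32        # tokens per cache block (block_size=32 performs better here)
KV_LORA_RANK = 512     # compressed kv latent dimension
NUM_BLOCKS = 16
NUM_HEADS = 4
HEAD_DIM = 128

# a single block pool; there is no separate k-cache and v-cache
latent_cache = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, KV_LORA_RANK)

# illustrative up-projection weights (per layer in the real model)
w_uk = torch.randn(KV_LORA_RANK, NUM_HEADS * HEAD_DIM) * 0.02
w_uv = torch.randn(KV_LORA_RANK, NUM_HEADS * HEAD_DIM) * 0.02


def append_token(block_table, pos, c_kv):
    """Write the compressed latent of token `pos` into its paged slot."""
    block_id = block_table[pos // BLOCK_SIZE]
    latent_cache[block_id, pos % BLOCK_SIZE] = c_kv


def gather_kv(block_table, seq_len):
    """Read the shared latent back and expand it into both k and v."""
    blocks = latent_cache[block_table]                 # [n_blocks, BLOCK_SIZE, rank]
    c_kv = blocks.reshape(-1, KV_LORA_RANK)[:seq_len]  # [seq_len, rank]
    k = (c_kv @ w_uk).view(seq_len, NUM_HEADS, HEAD_DIM)
    v = (c_kv @ w_uv).view(seq_len, NUM_HEADS, HEAD_DIM)
    return k, v


table = [3, 7]                        # logical -> physical block mapping
for pos in range(40):                 # 40 tokens span two blocks
    append_token(table, pos, torch.randn(KV_LORA_RANK))
k, v = gather_kv(table, 40)
print(k.shape, v.shape)               # both come from the same cache blocks
```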
requirements:
- [x] #1520
- [ ] #1603
Hi @lvhan028 After this PR is ready and merged, will LMDeploy publish a new release? Thanks. @grimoire
result of deepseek-v2-lite
Hi @grimoire Are the current performance benchmark results as expected, and how much of a lead is there compared to vLLM? Thanks.
https://github.com/deepseek-ai/DeepSeek-V2?tab=readme-ov-file#inference-with-vllm-recommended
@zhyncs the latest profile result (256 concurrency, 3000 prompts, block_size=32, and --cache-max-entry-count adjusted to prevent OOM):
- deepseek v2: 3.627 req/s
- deepseek v2 lite: 10.404 req/s
Apart from the fact that the default value cannot be used for block_size, the rest is relatively acceptable.
We have not benchmarked vLLM yet; 8 A100s are not always available (T T).
@grimoire The unit test TestMBGMV.test_mbgmv failed; it may need a fix.
Hi @grimoire, I used your commit to run the workflow at https://github.com/zhyncs/lmdeploy/actions/runs/9584655537 and obtained the whl https://github.com/zhyncs/dl/releases/tag/0620. And I encountered an error https://github.com/triton-lang/triton/issues/4172. Do you have any ideas? Thanks!
Triton ships a prepackaged ptxas, which might not match your CUDA driver version. You can point to your own ptxas (/path/to/cuda/bin/ptxas) with the TRITON_PTXAS_PATH environment variable, for example:
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
It works for me. Thanks and cheers. @grimoire
- deepseek v2 lite: 10.404 req/s
Hi @grimoire May I ask if this uses a single A100 card or 8 cards? Thanks.
LMDeploy https://github.com/zhyncs/dl/releases/tag/0620 https://github.com/zhyncs/lmdeploy/actions/runs/9584655537 https://github.com/grimoire/lmdeploy/commit/995e0ed972a024445cfb9ae44b292e4907c4a798
single A100
# server
python3 -m lmdeploy serve api_server DeepSeek-V2-Lite --backend pytorch --cache-block-seq-len 32
# client
# https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py
python3 benchmark_serving.py --backend lmdeploy --host 127.0.0.1 --port 23333 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --num-prompts 1000 --request-rate 128
result
# ignore_eos false
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 154.05
Total input tokens: 236142
Total generated tokens: 148682
Request throughput (req/s): 6.49
Input token throughput (tok/s): 1532.88
Output token throughput (tok/s): 965.14
---------------Time to First Token----------------
Mean TTFT (ms): 56583.14
Median TTFT (ms): 55727.01
P99 TTFT (ms): 113475.30
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 116.80
Median TPOT (ms): 90.45
P99 TPOT (ms): 475.46
---------------Inter-token Latency----------------
Mean ITL (ms): 77.64
Median ITL (ms): 58.83
P99 ITL (ms): 430.49
==================================================
# ignore_eos true
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 181.48
Total input tokens: 236142
Total generated tokens: 215605
Request throughput (req/s): 5.51
Input token throughput (tok/s): 1301.17
Output token throughput (tok/s): 1188.01
---------------Time to First Token----------------
Mean TTFT (ms): 65216.61
Median TTFT (ms): 65241.68
P99 TTFT (ms): 135946.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 96.65
Median TPOT (ms): 80.46
P99 TPOT (ms): 267.16
---------------Inter-token Latency----------------
Mean ITL (ms): 72.56
Median ITL (ms): 60.45
P99 ITL (ms): 372.03
==================================================
Hi @grimoire May I ask if this uses a single A100 card or 8 cards? Thanks.
It is profiled with a single A100. The bottleneck of the lite model is on the host side; TP would make it worse.
`block_size=32` would have better performance.
May we set the cache-block-seq-len to 32 by default when running DeepSeek V2 for inference? From the benchmark results, there is a significant performance gap.
# python3 benchmark_serving.py --backend lmdeploy --host 127.0.0.1 --port 23333 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --num-prompts 1000 --request-rate 128
# cache-block-seq-len 32, ignore_eos true
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 181.48
Total input tokens: 236142
Total generated tokens: 215605
Request throughput (req/s): 5.51
Input token throughput (tok/s): 1301.17
Output token throughput (tok/s): 1188.01
---------------Time to First Token----------------
Mean TTFT (ms): 65216.61
Median TTFT (ms): 65241.68
P99 TTFT (ms): 135946.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 96.65
Median TPOT (ms): 80.46
P99 TPOT (ms): 267.16
---------------Inter-token Latency----------------
Mean ITL (ms): 72.56
Median ITL (ms): 60.45
P99 ITL (ms): 372.03
==================================================
# cache-block-seq-len default, ignore_eos true
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 384.18
Total input tokens: 236142
Total generated tokens: 215594
Request throughput (req/s): 2.60
Input token throughput (tok/s): 614.67
Output token throughput (tok/s): 561.18
---------------Time to First Token----------------
Mean TTFT (ms): 155387.50
Median TTFT (ms): 153036.64
P99 TTFT (ms): 328194.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 196.80
Median TPOT (ms): 181.42
P99 TPOT (ms): 515.61
---------------Inter-token Latency----------------
Mean ITL (ms): 163.64
Median ITL (ms): 96.56
P99 ITL (ms): 1542.84
==================================================
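For reference, a rough Python-API counterpart of launching with `--cache-block-seq-len 32` might look like the sketch below. It assumes the `PytorchEngineConfig` fields `block_size` and `cache_max_entry_count` in your lmdeploy version, and the cache ratio is only a placeholder.

```python
# Hedged sketch: Python-API counterpart of
#   lmdeploy serve api_server DeepSeek-V2-Lite --backend pytorch --cache-block-seq-len 32
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(
    block_size=32,               # paged-attention block size (default is 64)
    cache_max_entry_count=0.6,   # placeholder ratio; lower it if you hit OOM
)
pipe = pipeline('DeepSeek-V2-Lite', backend_config=backend_config)
print(pipe(['Hello, who are you?']))
```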
LGTM
hold on plz. @grimoire @RunningLeon
python3 -m lmdeploy serve api_server /workdir/DeepSeek-V2-Lite-Chat --backend pytorch
# running it multiple times gives different results with temperature 0
python3 benchmark/profile_restful_api.py 127.0.0.1:23333 /workdir/DeepSeek-V2-Lite-Chat /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model_name /workdir/DeepSeek-V2-Lite-Chat --num_prompts 1 --concurrency 1 --temperature 0
@zhyncs temperature=0 is an invalid value https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/src/transformers/generation/logits_process.py#L298
I set it to 1 if temperature <= 0 in the PyTorch engine. https://github.com/InternLM/lmdeploy/blob/9e8cb3c4948c0160f00d5401d4519b192aee6581/lmdeploy/pytorch/messages.py#L75
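Roughly, that guard can be read like the following paraphrase (not the exact code in messages.py):

```python
def safe_temperature(temperature: float) -> float:
    """Fall back to 1.0 when the requested temperature is not strictly positive,
    so the HF-style constraint (temperature > 0) is never violated."""
    return temperature if temperature > 0 else 1.0

assert safe_temperature(0.0) == 1.0
assert safe_temperature(0.7) == 0.7
```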
@grimoire If temperature=0 is not supported, how can I get a deterministic answer?
TurboMind supports temperature in [0, 2].
Just set top_k=1 or use a small enough temperature: https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/src/transformers/generation/logits_process.py#L275-L279
Note that a small temperature might still lead to different results if two values in the logits are close.
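A hedged sketch of the deterministic route suggested above, assuming the lmdeploy Python API (`pipeline`, `GenerationConfig`, `PytorchEngineConfig`); parameter names may differ between versions:

```python
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

pipe = pipeline('/workdir/DeepSeek-V2-Lite-Chat',
                backend_config=PytorchEngineConfig(block_size=32))

# top_k=1 makes the sampler pick the argmax token, so the temperature value
# no longer matters; a merely small temperature can still flip between two
# near-equal logits.
gen_cfg = GenerationConfig(top_k=1, max_new_tokens=128)
print(pipe(['Who are you?'], gen_config=gen_cfg))
```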