Turbomind prefix caching
Motivation
https://github.com/InternLM/lmdeploy/issues/1407
Modification
- [x] Turbomind change
- [ ] Add CLI option after https://github.com/InternLM/lmdeploy/pull/1429 is merged (see the usage sketch below)
- [x] Benchmark and evaluation
- [x] Compatibility testing with AWQ, online KV Cache Int4/Int8 and tp
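For reference, a minimal usage sketch of how this could be switched on from the Python API once the follow-up option lands (assuming it is exposed as `enable_prefix_caching` on `TurbomindEngineConfig`; the exact name depends on the follow-up to #1429):

```python
# Hedged sketch, not part of this PR: assumes the switch is exposed as
# `enable_prefix_caching` on TurbomindEngineConfig.
from lmdeploy import pipeline, TurbomindEngineConfig

backend_config = TurbomindEngineConfig(enable_prefix_caching=True)
pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)

# Requests sharing the same system prompt can then reuse cached KV blocks.
print(pipe(['Hello, who are you?', 'What can you do?']))
```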
TODO: need to add compatibility testing with AWQ, online KV Cache Int4 and Int8 @ispobock
Also need to test the case when TP is turned on.
Benchmarked with the method mentioned in https://github.com/InternLM/lmdeploy/issues/1407#issuecomment-2046702194. Settings:
engine: Turbomind
model: llama2-13B-chat
num prompts: 1000
Using the LMDeploy benchmark script (the one used in https://github.com/InternLM/lmdeploy/pull/1429#issuecomment-2063156779): w/o prefix caching:
concurrency: 128
elapsed_time: 168.270s
number of prompt tokens: 332115
number of completion tokens: 241536
token throughput (completion token): 1435.405 token/s
token throughput (prompt + completion token): 3409.104 token/s
RPS (request per second): 5.943 req/s
RPM (request per minute): 356.569 req/min
with prefix caching:
concurrency: 128
elapsed_time: 146.064s
number of prompt tokens: 332115
number of completion tokens: 241536
token throughput (completion token): 1653.630 token/s
token throughput (prompt + completion token): 3927.392 token/s
RPS (request per second): 6.846 req/s
RPM (request per minute): 410.779 req/min
Using the vLLM benchmark script (the one used in https://github.com/InternLM/lmdeploy/issues/1407#issuecomment-2046702194): w/o prefix caching:
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 112.31
Total input tokens: 336509
Total generated tokens: 160192
Request throughput (req/s): 8.90
Input token throughput (tok/s): 2996.25
Output token throughput (tok/s): 1426.34
---------------Time to First Token----------------
Mean TTFT (ms): 39691.85
Median TTFT (ms): 35011.06
P99 TTFT (ms): 101250.62
with prefix caching:
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 97.21
Total input tokens: 336509
Total generated tokens: 160178
Request throughput (req/s): 10.29
Input token throughput (tok/s): 3461.66
Output token throughput (tok/s): 1647.75
---------------Time to First Token----------------
Mean TTFT (ms): 33815.43
Median TTFT (ms): 31043.85
P99 TTFT (ms): 86314.32
We can see almost a 15% throughput improvement when enabling prefix caching for the Turbomind engine. Note that the token_id length of the system prompts added in https://github.com/InternLM/lmdeploy/issues/1407#issuecomment-2046702194 is 116, which means only 1 block is reused per request. The improvement will be more significant with longer system prompts.
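As a rough sanity check on the block math (a sketch, assuming Turbomind's cache_block_seq_len of 64 tokens; only complete blocks of the shared prefix can be matched):

```python
# Sketch: number of complete KV-cache blocks a shared prefix can reuse,
# assuming each block holds `block_seq_len` tokens (64 assumed here;
# the value is configurable via cache_block_seq_len).
def reusable_blocks(prefix_len: int, block_seq_len: int = 64) -> int:
    return prefix_len // block_seq_len

print(reusable_blocks(116))   # 1  -> only one block reused in this benchmark
print(reusable_blocks(1024))  # 16 -> a longer system prompt reuses many more
```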
Evaluation result for Internlm2-7b with prefix caching:
dataset version metric mode internlm2-7b-turbomind
-------------------------------------- --------- ------------- ------ ------------------------
--------- Exam --------- - - - -
ceval - naive_average gen 64.29
agieval - - - -
mmlu - naive_average gen 62.46
GaokaoBench - - - -
ARC-c - - - -
--------- Language --------- - - - -
WiC d06864 accuracy gen 56.43
summedits - - - -
chid-dev - - - -
afqmc-dev - - - -
bustm-dev - - - -
cluewsc-dev - - - -
WSC 7902a7 accuracy gen 53.85
winogrande - - - -
flores_100 - - - -
--------- Knowledge --------- - - - -
BoolQ - - - -
commonsense_qa - - - -
nq - - - -
triviaqa 2121ce score gen 61.45
--------- Reasoning --------- - - - -
cmnli - - - -
ocnli - - - -
ocnli_fc-dev - - - -
AX_b - - - -
AX_g - - - -
CB - - - -
RTE - - - -
story_cloze - - - -
COPA - - - -
ReCoRD - - - -
hellaswag - - - -
piqa - - - -
siqa - - - -
strategyqa - - - -
math - - - -
gsm8k 1d7fe4 accuracy gen 71.19
TheoremQA - - - -
openai_humaneval - - - -
mbpp - - - -
bbh - - - -
--------- Understanding --------- - - - -
C3 - - - -
CMRC_dev - - - -
DRCD_dev - - - -
MultiRC - - - -
race-middle 9a54b6 accuracy gen 22.77
race-high 9a54b6 accuracy gen 22.53
openbookqa_fact - - - -
csl_dev - - - -
lcsts - - - -
Xsum - - - -
eprstmt-dev - - - -
lambada - - - -
tnews-dev - - - -
Fantastic job! Thanks so much. We plan to release v0.4.0 next Tuesday, mainly focusing on new VLMs support and kv4/8 quantization and inference. Regarding the prefix caching of both engines, I would like to highlight it in v0.5.0, which is planned to be published around May 20th
ok
Could you also help review the code and give some suggestions? Thanks. @lvhan028 @lzhangzz @grimoire
@lvhan028 Are there any planned features for the Turbomind engine in the next month? Hopefully there won't be too many code conflicts.
especially in LlamaBatch
There are definitely conflicts due to #1458
There is almost no impact. The refactoring of LlamaBatch (decoupling batch and model) will happen after v0.5.0, so there will not be any major impact.
Got it. It seems there is no big conflict with this feature.
The evaluation result for Turbomind prefix caching + AWQ + online kv cache int4 + tp2:
dataset version metric mode internlm2-chat-7b-4bits-turbomind
-------------------------------------- --------- ------------- ------ ----------------------------------
--------- Exam --------- - - - -
ceval - naive_average gen 51.35
agieval - - - -
mmlu - naive_average gen 53.39
GaokaoBench - - - -
ARC-c - - - -
--------- Language --------- - - - -
WiC d06864 accuracy gen 52.19
summedits - - - -
chid-dev - - - -
afqmc-dev - - - -
bustm-dev - - - -
cluewsc-dev - - - -
WSC 7902a7 accuracy gen 63.46
winogrande - - - -
flores_100 - - - -
--------- Knowledge --------- - - - -
BoolQ - - - -
commonsense_qa - - - -
nq - - - -
triviaqa 2121ce score gen 40.64
--------- Reasoning --------- - - - -
cmnli - - - -
ocnli - - - -
ocnli_fc-dev - - - -
AX_b - - - -
AX_g - - - -
CB - - - -
RTE - - - -
story_cloze - - - -
COPA - - - -
ReCoRD - - - -
hellaswag - - - -
piqa - - - -
siqa - - - -
strategyqa - - - -
math - - - -
gsm8k 1d7fe4 accuracy gen 39.73
TheoremQA - - - -
openai_humaneval - - - -
mbpp - - - -
bbh - - - -
--------- Understanding --------- - - - -
C3 - - - -
CMRC_dev - - - -
DRCD_dev - - - -
MultiRC - - - -
race-middle 9a54b6 accuracy gen 74.16
race-high 9a54b6 accuracy gen 67.87
openbookqa_fact - - - -
csl_dev - - - -
lcsts - - - -
Xsum - - - -
eprstmt-dev - - - -
lambada - - - -
tnews-dev - - - -
The evaluation result for AWQ + online kv cache int4 + tp2, without Turbomind prefix caching:
dataset version metric mode internlm2-chat-7b-4bits-turbomind
-------------------------------------- --------- ------------- ------ -----------------------------------
--------- Exam --------- - - - -
ceval - naive_average gen 50.92
agieval - - - -
mmlu - naive_average gen 53.68
GaokaoBench - - - -
ARC-c - - - -
--------- Language --------- - - - -
WiC d06864 accuracy gen 53.29
summedits - - - -
chid-dev - - - -
afqmc-dev - - - -
bustm-dev - - - -
cluewsc-dev - - - -
WSC 7902a7 accuracy gen 67.31
winogrande - - - -
flores_100 - - - -
--------- Knowledge --------- - - - -
BoolQ - - - -
commonsense_qa - - - -
nq - - - -
triviaqa 2121ce score gen 40.48
--------- Reasoning --------- - - - -
cmnli - - - -
ocnli - - - -
ocnli_fc-dev - - - -
AX_b - - - -
AX_g - - - -
CB - - - -
RTE - - - -
story_cloze - - - -
COPA - - - -
ReCoRD - - - -
hellaswag - - - -
piqa - - - -
siqa - - - -
strategyqa - - - -
math - - - -
gsm8k 1d7fe4 accuracy gen 40.03
TheoremQA - - - -
openai_humaneval - - - -
mbpp - - - -
bbh - - - -
--------- Understanding --------- - - - -
C3 - - - -
CMRC_dev - - - -
DRCD_dev - - - -
MultiRC - - - -
race-middle 9a54b6 accuracy gen 74.30
race-high 9a54b6 accuracy gen 67.52
openbookqa_fact - - - -
csl_dev - - - -
lcsts - - - -
Xsum - - - -
eprstmt-dev - - - -
lambada - - - -
tnews-dev - - - -
The differences are mainly caused by the sampling settings in the evaluation code. The results with and without prefix caching are close, which indicates these features are compatible.
We need to merge the main branch regularly before approving, to avoid conflicts.
Also, in the current implementation, (re-)computation of shared blocks is not shared (even though the memory blocks are shared and may be re-written multiple times).
@lzhangzz In the current implementation, the blocks in the block trie are already computed and read-only. We only cache and match computed blocks, so shared blocks will not be re-written multiple times.
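To illustrate the point (a simplified sketch, not the actual BlockTrie code): completed blocks are keyed by a hash of all token ids from the start of the sequence, and a lookup only ever returns blocks that were already computed in an earlier iteration, so matched blocks are read-only.

```python
# Simplified sketch of prefix caching over computed blocks; the names and the
# flat dict are illustrative, the real BlockTrie walks a per-block trie.
from hashlib import sha256

BLOCK_SEQ_LEN = 64  # tokens per KV-cache block (assumed)


def _key(token_ids):
    return sha256(str(token_ids).encode()).hexdigest()


class PrefixCache:
    def __init__(self):
        self._blocks = {}  # prefix hash -> block id

    def cache(self, token_ids, block_ids):
        """Register complete, already-computed blocks after an iteration."""
        for i in range(len(token_ids) // BLOCK_SEQ_LEN):
            prefix = token_ids[:(i + 1) * BLOCK_SEQ_LEN]
            self._blocks.setdefault(_key(prefix), block_ids[i])

    def match(self, token_ids):
        """Return block ids of the longest cached prefix of a new request."""
        matched = []
        for i in range(len(token_ids) // BLOCK_SEQ_LEN):
            block = self._blocks.get(_key(token_ids[:(i + 1) * BLOCK_SEQ_LEN]))
            if block is None:
                break
            matched.append(block)
        return matched
```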
I see.
However, once a cached block is evicted, sharing that block seems problematic.
@lzhangzz In the current setting, only blocks with use_count = 1 (i.e., only the block trie holds the use_count) can be evicted. That means no sequence is using the block. After eviction, the block will be freed and reallocated.
This is also a problem. The idea of "occupied blocks can't be evicted" breaks the relaxed FCFS scheduling and may lead to starvation. It makes the ownership of shared blocks transferable between requests (in an uncontrollable way).
Imagine the following request order, assuming the A_x share a large chunk of blocks [b]:
A_0 C A_1 A_2 A_3 ...
As more A_x arrive, C may never get a chance to run because [b] is always occupied.
And use_count only tracks active sequences; sequences in interactive mode can still reference the same invalidated block. When re-computation happens later, those sequences will allocate and refill the shared blocks independently.
@lzhangzz Preemption logic can be applied to solve the starvation problem: C has higher priority, so the blocks held by later A_x will be preempted.
Preemption won't work, because BlockTrie holds a dummy reference and BlockTrie::evict operates directly on the actual use_count, while the preemption logic operates on a snapshot of use_count.
We still have some unresolved issues:
- With a batch of sequences sharing previously unseen (or evicted) prefixes, neither computation nor cache blocks are shared.
- When something in the prefix cache gets evicted, it can't get into the cache again until a request arrives with the same prefix as its prompt.
@lzhangzz
- The first one is expected. Currently we only cache and reuse computed blocks to avoid write conflicts. In this design, block reuse may have a one-iteration delay: for a new request to match prefix blocks, the blocks must have been cached in a previous iteration. I don't think it will affect overall performance too much.
- For the second one, I am not sure I fully understand. Blocks in the prefix cache will not be evicted when use_count = 0. They will only be evicted when they are re-allocated (indicated by a unique_id mismatch during verification). Before the re-allocation, they can still be reused. If a block is re-allocated, we don't need to get it back into the cache.
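A minimal sketch of the eviction and verification rules described here (field names follow the discussion; the types are illustrative, not the real SequenceManager/BlockTrie structures):

```python
# Sketch: a trie entry remembers the block's unique_id at caching time, so a
# later re-allocation of the block shows up as a mismatch during verification
# and the stale entry is dropped instead of being matched.
from dataclasses import dataclass


@dataclass
class Block:
    unique_id: int   # bumped every time the block is (re-)allocated
    use_count: int   # the block trie itself counts as one user


@dataclass
class TrieEntry:
    block: Block
    cached_unique_id: int  # the block's unique_id when it was cached


def can_evict(entry: TrieEntry) -> bool:
    # Only the trie holds the block, i.e. no sequence is currently using it.
    return entry.block.use_count == 1


def is_valid(entry: TrieEntry) -> bool:
    # A unique_id mismatch means the block was freed and re-allocated since
    # it was cached, so the entry must not be matched again.
    return entry.block.unique_id == entry.cached_unique_id
```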
@lzhangzz The original intention of this feature is to cache shared system prefixes. Especially in search scenarios at Internet companies, the common practice is to run SFT on a SOTA chat model and use prompt engineering to turn it into a question-answering assistant for a vertical domain. This usually means requests share one or more common or similar system prefixes. The current implementation is compatible with LMDeploy's existing design and meets the requirements above, so we believe it meets expectations. What you are describing is a more general cache, and we could discuss that suggestion further. Regarding the first point, I also agree with @ispobock that it does not have a significant impact on performance.
Hi @irexyc, could you help review the PR? Thanks.
@irexyc Please help check that this will not break VLMs when embedding inputs are present.
Please add `cc` to the extensions and fix the formatting issues.
docker run -it --rm --workdir /src -v $(pwd):/src clang-format-lint --clang-format-executable /clang-format/clang-format11 -r --inplace True --style=file --extensions "h,c,cpp,hpp,cu,cuh,cc" src
modified: src/turbomind/models/llama/BlockTrie.cc
modified: src/turbomind/models/llama/LlamaBatch.cc
modified: src/turbomind/models/llama/SequenceManager.cc