Prefix Caching
add prefix caching support
Section 1 (Basic Functionality):
- [x] Test on single request for llama model (no cpu swap)
- [x] Test on batches where there are requests with prefix and without prefix (no cpu swap)
- [x] Benchmark performance for batched analytics tasks (no cpu swap)
- [x] ALiBi
- [ ] Test other models
- [x] Clean code
Future ToDo: Section 2 (Swapping and Auto Detection):
- [ ] Test CPU swap
- [ ] Optimize the prefix kernel
- [ ] Add prefix cache swap-in/out policy
- [ ] Better Interface & Auto prefix detection?
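For reference, a minimal usage sketch of the interface added here, assuming `prefix_pos` works as in the test scripts later in this thread (the model path, prompts, and prefix length below are placeholders):

```python
from vllm import LLM, SamplingParams

# Placeholders: model path, prompts, and prefix length are illustrative only.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
sampling_params = SamplingParams(temperature=0)

shared_prefix = "You are a helpful assistant. <long shared system prompt> "
prompts = [shared_prefix + f"Question {i}: ..." for i in range(4)]
prefix_len = 591  # token length of shared_prefix, computed offline with the model's tokenizer

# Warm up with a single prompt so the prefix KV cache is computed once,
# then run the whole batch reusing the cached prefix.
llm.generate(prompts[0], sampling_params, prefix_pos=[prefix_len])
outputs = llm.generate(prompts, sampling_params, prefix_pos=[prefix_len] * len(prompts))
```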
It seems that the prefix has not updated its physical block?
I tested on meta-llama/Llama-2-70b-chat-hf and Baichuan2-13B-Chat, but there seems to be no acceleration effect.
Then I added a logger call at line 358 of worker.py to check whether the prefix updated its physical block.
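The call I added was roughly the following (only the `logger.info` line is what was actually inserted; the surrounding code in worker.py is not shown):

```python
# Around line 358 of vllm/worker/worker.py, inside the input-preparation path:
logger.info(f"prefix_block_tables: {prefix_block_tables}")
```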
The logger's output is as follows:

```
------start generating------
prefix length: 592
block size: 16
Processed prompts: 0%| | 0/500 [00:00<?, ?it/s]
(RayWorker pid=1138291) INFO 12-05 10:46:31 worker.py:358] prefix_block_tables: [[], [], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:36 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:42 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:47 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: []
(RayWorker pid=1138290) INFO 12-05 10:46:51 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 39x across cluster]
Processed prompts: 0%| | 1/500 [00:23<3:18:23, 23.86s/it]
(RayWorker pid=1138290) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:46:55 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:57 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 21x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:02 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:08 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:13 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []]
(RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: []
Processed prompts: 43%|████▎ | 217/500 [00:45<00:50, 5.65it/s]
(RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []]
Processed prompts: 44%|████▍ | 219/500 [00:45<00:49, 5.65it/s]
(RayWorker pid=1138291) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]]
(RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 19x across cluster]
Processed prompts: 44%|████▍ | 220/500 [00:47<00:53, 5.24it/s]
(RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:18 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]] [repeated 3x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:23 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 41x across cluster]
Processed prompts: 100%|██████████| 500/500 [00:54<00:00, 9.13it/s]
cost time 56.1341028213501
saving output
(RayWorker pid=1138291) INFO 12-05 10:47:26 worker.py:358] prefix_block_tables: []
```
By the way, my prompts are generated by this script:

```python
# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"
prompt_template = '''\
You are a helpful assistant in recongnizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.
# Table
{}
# Question
What' s the content in the ({},{}) cells
'''
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            # outer.write(f"{tmp_str}")
            print(tmp_str.replace("\n", "\\n"), file=outer)
```
Maybe there are some bugs in my test?
@DouHappy Can you try calling `llm.generate()` with one prompt first to warm up? `prefix_block_tables=[[],[],[]]` indicates that the KV cache for the prefix part hasn't been computed yet. Also, can you share your testing script?
My test script:

```python
# %%
# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"
prompt_template = '''\
You are a helpful assistant in recongnizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.
# Table
{}
# Question
What' s the content in the ({},{}) cells
'''
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            # outer.write(f"{tmp_str}")
            print(tmp_str.replace("\n", "\\n"), file=outer)
# %%
import time
import datetime
import os
from vllm import LLM
from vllm import SamplingParams
import torch
def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None, save_file=None):
    # set sampling_params
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0)
    print("------start generating------")
    start_time = time.time()
    # whether to use the prefix
    if prefix_len is not None:
        print("warmup")
        outputs = llm.generate(prompts[0], sampling_params=sampling_params, prefix_pos=[prefix_len])
        # start inference
        outputs = llm.generate(prompts, sampling_params=sampling_params, prefix_pos=[prefix_len] * len(prompts))
    else:
        outputs = llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    print(f"cost time {end_time - start_time}")
    if save_file is not None:
        print("saving output......")
        for output in outputs:
            print(output, file=save_file)
        print(f"output saved in {save_file.name} {datetime.datetime.now()}")
# %%
# set gpus
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
# init model and sampling parames
tensor_parallel_size = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
# set baichuan model
model = "/data/images/llms/models--baichuan-inc--Baichuan2-13B-Chat"
# model = "/data/images/llms/chatdoc-llama-2-70b-chat-hf-checkpoint-8000"
# Create an LLM.
llm = LLM(model=model, tokenizer_mode='auto', trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)
# %%
# get prompts
prompts = []
with open("prompt.txt", 'r') as reader:
    prompts = reader.readlines()[:500]
with open("output.txt", 'w') as f:
    test_prefix(llm=llm,
                prompts=prompts[:50],
                prefix_len=591,
                save_file=f)
```
I found a bug: using two GPUs is slower than a single GPU. The prefix's `on_gpu` state is always False before `prepare_inputs()` when using two GPUs, while it works fine on a single GPU. This means `multi_query_cached_kv_attention` is never used when running on multiple GPUs. My previous test also passed on a single GPU.
One GPU with prefix:

```
------start generating------
Processed prompts: 100%|██████████| 500/500 [00:22<00:00, 22.55it/s]
cost time 23.27279233932495
saving output
```

One GPU without prefix:

```
------start generating------
Processed prompts: 100%|██████████| 500/500 [01:21<00:00, 6.16it/s]
cost time 82.11750793457031
saving output
```

But it costs about 60 seconds on two GPUs. Although that is still fast, I can't get the right output when I use the prefix.
I propose the changes below to the `PrefixPool` class.

1. There is no need to maintain two data structures: one being a dict from `prefix_hash` to `prefix_id` and the other being the list of prefixes. We can simply maintain just the dictionary, mapping directly from `prefix_hash` to `Prefix`.
2. I think the prefix pool class should have a `max_capacity`; otherwise it has the potential to grow without bound. For the common use case the number of prefixes may be bounded or limited, but there are use cases where it is effectively unlimited. For example, consider prompts that extract information from a document: ideally we want the document text as the prefix, and the rest of the prompt carries the specific information to extract.
   2.1. Because of that, I propose implementing it using an `OrderedDict` so that we can easily implement a FIFO cache policy.
```python
import itertools
from collections import OrderedDict
from typing import Dict, List, Optional

from vllm.prefix import Prefix  # the Prefix class from this branch (import path assumed)


class PrefixPool:
    """Manages all the prompt prefixes.

    Args:
        block_size: The block size of the executed model.
        max_capacity: Max number of prefixes to keep in the pool at any given time.

    Attributes:
        prefixes: Mapping from the hash of the prefix to the Prefix object,
            ordered by insertion time.
        block_size: The block size of the executed model.
        max_capacity: Max number of prefixes to keep in the pool at any given time.
    """

    id_iter = itertools.count()

    def __init__(
        self,
        block_size: int,
        max_capacity: int = 32,
    ) -> None:
        # Dict from hash to prefix. It is an OrderedDict so that
        # we can easily implement a FIFO cache policy.
        self.prefixes: Dict[int, Prefix] = OrderedDict()
        self.block_size = block_size  # number of tokens per memory block
        self.max_capacity = max_capacity

    def add_prefix(self, token_ids: List[int]) -> Prefix:
        # Evict the oldest prefix once the pool is at capacity (FIFO policy).
        if len(self.prefixes) >= self.max_capacity:
            self.prefixes.popitem(last=False)
        prefix_hash = hash(tuple(token_ids))
        assert prefix_hash not in self.prefixes
        # Generate prefix_id and create the new prefix.
        prefix_id = next(self.id_iter)
        prefix = Prefix(prefix_id, token_ids, self.block_size)
        self.prefixes[prefix_hash] = prefix
        return prefix

    # Use this first, if we already know from the application
    # which part of the tokens is the prefix.
    def fixed_search(self, prefix_hash: int) -> Optional[Prefix]:
        return self.prefixes.get(prefix_hash, None)
```
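For illustration, a small hypothetical usage of the proposed pool (token ids are made up; `Prefix` is constructed exactly as in the snippet above) showing the FIFO eviction behavior:

```python
# A pool with a tiny capacity, just to demonstrate FIFO eviction.
pool = PrefixPool(block_size=16, max_capacity=2)

p0 = pool.add_prefix(list(range(32)))       # oldest prefix
p1 = pool.add_prefix(list(range(32, 64)))
p2 = pool.add_prefix(list(range(64, 96)))   # pool is full, so p0 is evicted

assert pool.fixed_search(hash(tuple(range(32)))) is None       # evicted
assert pool.fixed_search(hash(tuple(range(64, 96)))) is p2     # still cached
```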
- I am interested in contributing to this branch. I have opened a PR with some changes to it. Feel free to review it: https://github.com/caoshiyi/vllm/pull/2
- The changes introduced in that PR do not yet add new functionality to the vLLM engine, but they are stepping stones to make that happen in a later PR.
- The idea that I have in mind is the following: whenever the processing of a `SequenceGroup` is either finished or aborted, check the `prefix_pool: PrefixPool` object and see if the `prefix` (if any) of that `SequenceGroup` still exists in the `PrefixPool`. If it does not exist, free the GPU memory for that prefix (see the sketch below).
I have read the code base extensively, and I don't think that adding this new functionality affects the behavior of the system as originally intended, given that by default `PrefixPool` behaves as if no `max_capacity` had been set.
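A rough sketch of that check follows; the hook point and the `free_prefix_blocks` helper are hypothetical, and the real integration would live in the scheduler / block manager with possibly different attribute names:

```python
def on_seq_group_done(seq_group, prefix_pool, block_manager) -> None:
    """Hypothetical hook, called when a SequenceGroup finishes or is aborted."""
    # Assumes SequenceGroup carries an optional `prefix` and Prefix keeps its token_ids.
    prefix = getattr(seq_group, "prefix", None)
    if prefix is None:
        return
    prefix_hash = hash(tuple(prefix.token_ids))
    # If the prefix has been evicted from the pool (e.g. by the FIFO policy),
    # its cached KV blocks on the GPU are no longer reachable through the pool
    # and can be freed.
    if prefix_pool.fixed_search(prefix_hash) is None:
        block_manager.free_prefix_blocks(prefix)  # hypothetical block-manager helper
```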
@jadielam Thank you for your contribution! How about the following plan: I will fix the bugs of the PR and first merge a version without prefix pool capacity, and then you create a new PR in the vLLM repo to add max_capacity to prefix pool?
sounds good.
Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)
Hi @DouHappy , did you observe any speed improvement afterwards?
Yes, I did observe a speedup. Could you show me your test script? Maybe you forgot the warmup? BTW, I am writing an introduction to prefix caching, but it is Chinese-only for now. See vLLM-prefix浅析 (System Prompt, 大模型推理加速). @franklyd
@franklyd @DouHappy There was a bug in my refactor. If you try now, you should be able to see speedups.
Could you provide a test script for the speedup?
+1
Hi @HaiShaw,
Triton doesn't seem to support mixed-precision dot products, so this kernel here fails if `k` is uint8 and `q` is another precision. I've been trying to find a solution to this problem, but keep coming up blank. Do you have any ideas on how to approach this?
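For reference, a tiny standalone sketch (not the actual vLLM prefix kernel; shapes, strides, and dtypes are simplified) of the pattern that fails:

```python
import triton
import triton.language as tl

@triton.jit
def qk_dot_kernel(q_ptr, k_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # q is loaded as fp16; k comes from a quantized cache stored as uint8.
    q = tl.load(q_ptr + offs[:, None] * BLOCK + offs[None, :])
    k = tl.load(k_ptr + offs[:, None] * BLOCK + offs[None, :])
    # tl.dot(q, k) does not compile with mismatched fp16/uint8 operands.
    # Casting with k.to(tl.float16) makes the dot compile, but that is only a
    # valid dequantization if the uint8 values are an integer quantization;
    # if the bytes hold FP8 bit patterns, a bitcast-plus-scale scheme is needed.
    acc = tl.dot(q, k.to(tl.float16))
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], acc)
```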
Hi, @AlpinDale. Are you using prefix caching with the FP8 KV cache? The PyTorch and Triton versions used by vLLM cannot support the FP8 KV cache. There is more information about prefix caching and the FP8 KV cache in #3234.