
Prefix Caching

Open caoshiyi opened this issue 2 years ago • 3 comments

add prefix caching support

Section 1 (Basic Functionality):

  • [x] Test on single request for llama model (no cpu swap)
  • [x] Test on batches where there are requests with prefix and without prefix (no cpu swap)
  • [x] Benchmark performance for batched analytics tasks (no cpu swap)
  • [x] Alibi
  • [ ] Test other models
  • [x] Clean code

Future ToDo: Section 2 (Swapping and Auto Detection):

  • [ ] Test CPU swap
  • [ ] Optimize the prefix kernel
  • [ ] Add prefix cache swap-in/out policy
  • [ ] Better Interface & Auto prefix detection?

caoshiyi avatar Nov 15 '23 07:11 caoshiyi

It seems that the prefix has not updated its physical block?

I tested with meta-llama/Llama-2-70b-chat-hf and baichuan2-13b-chat, but there seems to be no acceleration effect. I then added a logger at line 358 of worker.py to check whether the prefix had updated its physical block.
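The screenshot is not reproduced here; the added logging was roughly along these lines (a sketch, with the logger and variable names assumed rather than copied from worker.py):

# Hypothetical sketch of the debug logging added around worker.py:358 in the prefix branch.
logger.info(f"prefix_block_tables: {prefix_block_tables}")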

The logger's output is as follows:

------start generating------
prefix length: 592
block size: 16
Processed prompts: 0%| | 0/500 [00:00<?, ?it/s]
(RayWorker pid=1138291) INFO 12-05 10:46:31 worker.py:358] prefix_block_tables: [[], [], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:36 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:42 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:47 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: []
(RayWorker pid=1138290) INFO 12-05 10:46:51 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 39x across cluster]
Processed prompts: 0%| | 1/500 [00:23<3:18:23, 23.86s/it]
(RayWorker pid=1138290) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:46:55 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:57 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 21x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:02 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:08 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:13 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []]
(RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: []
Processed prompts: 43%|████▎ | 217/500 [00:45<00:50, 5.65it/s]
(RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []]
Processed prompts: 44%|████▍ | 219/500 [00:45<00:49, 5.65it/s]
(RayWorker pid=1138291) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]]
(RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 19x across cluster]
Processed prompts: 44%|████▍ | 220/500 [00:47<00:53, 5.24it/s]
(RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:18 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]] [repeated 3x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:23 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 41x across cluster]
Processed prompts: 100%|██████████| 500/500 [00:54<00:00, 9.13it/s]
cost time 56.1341028213501
saving output

(RayWorker pid=1138291) INFO 12-05 10:47:26 worker.py:358] prefix_block_tables: []


By the way, my prompts are generated by this script:

# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"

prompt_template = '''\
You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.
# Table
{}

# Question
What's the content in the ({},{}) cell?
'''
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            # Write one prompt per line, with newlines escaped
            print(tmp_str.replace("\n", "\\n"), file=outer)

Maybe there are some bugs in my test?

DouHappy avatar Dec 05 '23 03:12 DouHappy

@DouHappy Can you try calling llm.generate() with one prompt first to warm up? prefix_block_tables=[[],[],[]] indicates that the KV cache for the prefix part hasn't been computed yet. Also, can you share your testing script?
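For reference, the warmup pattern with the prefix_pos argument used in this branch looks roughly like this (a sketch; the model path, prompts, and prefix length below are placeholders, not values from the report above):

from vllm import LLM, SamplingParams

# Warmup sketch for the prefix_pos API in this branch.
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0)

shared_prefix = "..."                    # the long shared context, e.g. a markdown table
questions = ["What's the content in the (1,1) cell?", "What's the content in the (1,2) cell?"]
prompts = [shared_prefix + q for q in questions]
prefix_len = 592                         # number of prefix tokens (placeholder; must match the tokenized shared_prefix)

# Warm up with one prompt so the prefix KV cache is computed and stored first.
llm.generate(prompts[:1], sampling_params=sampling_params, prefix_pos=[prefix_len])

# Later calls with the same prefix_pos can then reuse the cached prefix blocks.
outputs = llm.generate(prompts, sampling_params=sampling_params,
                       prefix_pos=[prefix_len] * len(prompts))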

caoshiyi avatar Dec 13 '23 02:12 caoshiyi

@DouHappy Can you try calling llm.generate() with one prompt first to warm up? prefix_block_tables=[[],[],[]] indicates that the KV cache for the prefix part hasn't been computed yet. Also, can you share your testing script?

My test script:

# %%
# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"

prompt_template = '''\
You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.
# Table
{}

# Question
What's the content in the ({},{}) cell?
'''
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            # Write one prompt per line, with newlines escaped
            print(tmp_str.replace("\n", "\\n"), file=outer)

# %%
import time
import datetime
import os

from vllm import LLM
from vllm import SamplingParams

import torch

def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None, save_file=None):
    # Set default sampling params
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0)

    print("------start generating------")
    start_time = time.time()
    # Whether to use the prefix
    if prefix_len is not None:
        print("warmup")
        outputs = llm.generate(prompts[0], sampling_params=sampling_params, prefix_pos=[prefix_len])
        # start inference
        outputs = llm.generate(prompts, sampling_params=sampling_params, prefix_pos=[prefix_len] * len(prompts))
    else:
        outputs = llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"cost time {end_time - start_time}")

    if save_file is not None:
        print("saving output......")
        for output in outputs:
            print(output, file=save_file)
        print(f"output saved in {save_file.name} {datetime.datetime.now()}")

# %%
# set gpus
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
# init model and sampling parames
tensor_parallel_size = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
# set baichuan model
model = "/data/images/llms/models--baichuan-inc--Baichuan2-13B-Chat"
# model = "/data/images/llms/chatdoc-llama-2-70b-chat-hf-checkpoint-8000"

# Create an LLM.
llm = LLM(model=model, tokenizer_mode='auto', trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)

# %%
# get prompts
prompts = []
with open("prompt.txt", 'r') as reader:
    prompts = reader.readlines()[:500]

with open("output.txt", 'w') as f:
    test_prefix(llm=llm,
                prompts=prompts[:50],
                prefix_len=591,
                save_file=f
                )



I found a bug: using two GPUs is slower than a single GPU. The prefix's 'on_gpu' state is always False before prepare_inputs() when using two GPUs, while it works fine on a single GPU. That means multi_query_cached_kv_attention is never used when running on multiple GPUs. My earlier test also passes on a single GPU.

one gpu with prefix
------start generating------
Processed prompts: 100%|██████████| 500/500 [00:22<00:00, 22.55it/s]
cost time 23.27279233932495
saving output

one gpu without prefix
------start generating------
Processed prompts: 100%|██████████| 500/500 [01:21<00:00,  6.16it/s]
cost time 82.11750793457031
saving output

but it costs about 60s on two GPUs. Although that is still fairly fast, I can't get the right output when I use the prefix.

DouHappy avatar Dec 13 '23 03:12 DouHappy

I propose the changes below to the PrefixPool class.

  1. There is no need to maintain two data structures: a dict from prefix_hash to prefix_id and a separate list of prefixes. We can maintain just one dictionary, mapping directly from prefix_hash to Prefix.

  2. I think that the prefix pool class should have a max_capacity, otherwise it has the potential to grow without bound. For the common use case the number of prefixes may be bounded, but there are use cases where it can be effectively unlimited. For example, consider prompts that extract information from documents: ideally we want the document text as the prefix, and the rest of the prompt would name the specific information to extract.

  2.1. Because of that, I propose implementing it using an OrderedDict so that we can easily implement a FIFO cache policy.

import itertools
from collections import OrderedDict
from typing import Optional

# Prefix is assumed to be the existing vLLM Prefix class defined alongside this pool.


class PrefixPool:
    """Manages all the prompt prefixes.

    Args:
        block_size: The block size of the executed model.
        max_capacity: Max number of prefixes to keep in the pool at any given time.

    Attributes:
        prefixes: Mapping from the hash of a prefix to its Prefix object.
        block_size: The block size of the executed model.
        max_capacity: Max number of prefixes to keep in the pool at any given time.
    """
    id_iter = itertools.count()

    def __init__(
        self,
        block_size: int,
        max_capacity: int = 32,
    ) -> None:
        # Dict from hash to Prefix. It is an OrderedDict so that
        # we can easily implement a FIFO cache policy.
        self.prefixes: dict[int, Prefix] = OrderedDict()

        self.block_size = block_size    # Number of tokens per memory block
        self.max_capacity = max_capacity

    def add_prefix(self, token_ids: list[int]) -> Prefix:
        # Evict the oldest prefix (FIFO) if the pool is full.
        if len(self.prefixes) >= self.max_capacity:
            self.prefixes.popitem(last=False)

        prefix_hash = hash(tuple(token_ids))
        assert prefix_hash not in self.prefixes

        # Generate a prefix_id and create the new prefix.
        prefix_id = next(self.id_iter)
        prefix = Prefix(prefix_id, token_ids, self.block_size)

        self.prefixes[prefix_hash] = prefix
        return prefix

    # Use this first if the application already knows which part of the tokens is the prefix.
    def fixed_search(self, prefix_hash: int) -> Optional[Prefix]:
        return self.prefixes.get(prefix_hash, None)
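For illustration, the proposed pool would be used roughly like this (a sketch; the Prefix constructor is the one referenced above):

# Usage sketch of the proposed PrefixPool.
pool = PrefixPool(block_size=16, max_capacity=4)

prefix_tokens = list(range(64))            # token ids of the shared prompt prefix
prefix_hash = hash(tuple(prefix_tokens))

# fixed_search returns None on a miss, in which case the prefix is registered.
prefix = pool.fixed_search(prefix_hash)
if prefix is None:
    prefix = pool.add_prefix(prefix_tokens)

# Once max_capacity prefixes have been added, add_prefix evicts the oldest entry (FIFO).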

jadielam avatar Jan 11 '24 12:01 jadielam

  • I am interested in contributing to this branch. I have opened a PR with some changes to it. Feel free to review it. https://github.com/caoshiyi/vllm/pull/2

  • The changes introduced in that PR do not add new functionality to the vLLM engine yet, but they are stepping stones to make that happen in a later PR.

  • The idea I have in mind is the following: whenever the processing of a SequenceGroup finishes or is aborted, check the prefix_pool: PrefixPool object to see whether the prefix (if any) of that SequenceGroup still exists in the PrefixPool. If it does not exist, free the GPU memory for that prefix (a rough sketch follows below).

I have read the code base extensively, and I don't think adding this new functionality changes the system's originally intended behavior, given that by default PrefixPool will behave as if no max_capacity had been set.
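A rough sketch of that check, purely for illustration (seq_group.prefix, prefix.hash, and block_manager.free_prefix are assumed names here, not the actual vLLM interfaces):

def maybe_free_prefix(seq_group, prefix_pool, block_manager):
    # Hypothetical cleanup hook, called when a SequenceGroup finishes or is aborted.
    prefix = getattr(seq_group, "prefix", None)
    if prefix is None:
        return
    # If the prefix has been evicted from the pool (e.g. by the FIFO policy),
    # its cached KV blocks are no longer useful and can be released.
    if prefix_pool.fixed_search(prefix.hash) is None:
        block_manager.free_prefix(prefix)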

jadielam avatar Jan 15 '24 18:01 jadielam

@jadielam Thank you for your contribution! How about the following plan: I will fix the bugs of the PR and first merge a version without prefix pool capacity, and then you create a new PR in the vLLM repo to add max_capacity to prefix pool?

zhuohan123 avatar Jan 15 '24 19:01 zhuohan123

@jadielam Thank you for your contribution! How about the following plan: I will fix the bugs of the PR and first merge a version without prefix pool capacity, and then you create a new PR in the vLLM repo to add max_capacity to prefix pool?

Sounds good.

jadielam avatar Jan 15 '24 20:01 jadielam

Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)

Hi @DouHappy , did you observe any speed improvement afterwards?

franklyd avatar Jan 16 '24 15:01 franklyd

Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)

Hi @DouHappy , did you observe any speed improvement afterwards?

Yes, I did observe a speedup. Could you show me your test script? Maybe you forgot the warmup? BTW, I am writing an introduction to prefix caching, though only a Chinese version for now; see "vLLM-prefix浅析 (System Prompt, 大模型推理加速)" (roughly, "A brief analysis of vLLM prefix: system prompts and LLM inference acceleration"). @franklyd

DouHappy avatar Jan 17 '24 10:01 DouHappy

@franklyd @DouHappy There was a bug in my refactor. If you try now, you should be able to see speedups.

zhuohan123 avatar Jan 17 '24 23:01 zhuohan123

Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)

Hi @DouHappy , did you observe any speed improvement afterwards?

Yes, I did observe a speedup. Could you show me your test script? Maybe you forgot the warmup? BTW, I am writing an introduction to prefix caching, though only a Chinese version for now; see "vLLM-prefix浅析 (System Prompt, 大模型推理加速)" (roughly, "A brief analysis of vLLM prefix: system prompts and LLM inference acceleration"). @franklyd

Could you provide a test script for the speedup?

gangooteli avatar Mar 11 '24 09:03 gangooteli

Could you provide a test script for the speedup?

+1

ksjadeja avatar Apr 13 '24 23:04 ksjadeja

Hi @HaiShaw

Triton doesn't seem to support mixed-precision dot products, so this kernel fails if k is uint8 and q is another precision. I've been trying to find a solution to this problem but keep coming up blank. Do you have any ideas on how to approach this?

AlpinDale avatar Apr 16 '24 20:04 AlpinDale

Hi @HaiShaw

Triton doesn't seem to support mixed-precision dot products, so this kernel fails if k is uint8 and q is another precision. I've been trying to find a solution to this problem but keep coming up blank. Do you have any ideas on how to approach this?

Hi @AlpinDale. Are you using prefix caching with an FP8 KV cache? The PyTorch and Triton versions used by vLLM cannot support an FP8 KV cache. There is more information about prefix caching and the FP8 KV cache in #3234.
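For context, the combination being discussed is roughly this configuration (a sketch; flag spellings vary across vLLM versions, and the model path is a placeholder):

from vllm import LLM

# Sketch of the configuration in question. With the Triton context-attention kernel
# at the time, this combination fails because the dot product would mix an
# FP8/uint8 KV cache with higher-precision queries.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8_e5m2",       # store the KV cache in FP8
          enable_prefix_caching=True)      # reuse cached prefix blocks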

chenxu2048 avatar Apr 23 '24 10:04 chenxu2048