
[RFC]: BatchLLM for better shared prefix utilization in offline scenarios


Motivation.

This request mainly targets offline inference scenarios and is based on the paper BatchLLM.

TL;DR: Currently, vLLM performs implicit (or just-in-time) shared-prefix identification and metadata collection, and then performs cascade attention when there is a single prefix shared by all requests, per PR #11635. However, it does not fully exploit shared prefixes in offline scenarios, where there are many requests with different shared prefixes. This RFC tries to alleviate the following pain points of vLLM's inference.

  • Point 1: Currently, vLLM's inference with prefix caching and cascade attention cannot gather all requests with the same common prefix together. This matters because, for the attention calculation, all query tokens with the same common prefix have to be treated as if they came from the same request.

  • Point 2: In offline scenarios, implicit shared-prefix identification is unnecessary, since all requests are available before inference starts. Implicit prefix caching is also not the best way to manage the KV cache of shared-prefix tokens.

  • Point 3: vLLM's cascade attention cannot support different common prefixes for different requests within one batch.

How BatchLLM tries to alleviate them

  • For Point 1, one simple and easy way is to sort all samples in the dataset (e.g., with Python's sorted()) before inference starts. This gathers the requests with the same prefix together; we then identify the shared prefixes of different requests explicitly and enlarge the shared prefix as much as possible.

  • For Point 2, we introduce the concept of a "prefix-sharing group", where a mini-set of requests shares the same common prefix. If the original requests look like:

a,b,c,d,e, f,g,h...
{---X---},{--Y1--}

a,b,c,d,e, i,j,k...
{---X---},{--Y2--}

...

here is what a prefix-sharing group looks like:

[X, [Y1, Y2, ...]]   # List[List, List[List]]: the shared prefix, then the list of non-shared contexts

where we put the common prefix X as the first element, followed by a list of all the non-shared contexts Y1 and Y2. If 2 requests share the same prefix, we separate them into the common prefix plus the 2 non-shared contexts, so vLLM handles 1 + 2 = 3 requests: BatchLLM first runs inference on the common prefix and saves its KV cache (as a single request without any decoding), then generates tokens for the 2 non-shared contexts/requests. Finally, once all requests in a prefix-sharing group are done, the KV cache of the common part is released. (A minimal sketch of this grouping step is shown right after this list.)

  • For Point 3, with the above changes it becomes much easier to collect the metadata related to cascade attention. However, the current flash-attn kernels cannot support cases where different requests in a single batch have different common prefixes. We have implemented a Triton version of the common/distinct/merge_2 kernels (similar to PR #11635), which shows good performance even with some extra Triton overhead.
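As an illustration of the grouping in Points 1 and 2, here is a minimal sketch in Python. The function name build_prefix_sharing_groups and the fixed prefix_len are assumptions for illustration only, not the actual BatchLLM implementation (which also enlarges the shared prefix automatically):

from collections import defaultdict
from typing import Dict, List, Tuple

def build_prefix_sharing_groups(prompts: List[List[int]],
                                prefix_len: int) -> List[Tuple[List[int], List[List[int]]]]:
    # Sorting gathers prompts with identical prefixes next to each other
    # (the sorted() trick from Point 1).
    groups: Dict[Tuple[int, ...], List[List[int]]] = defaultdict(list)
    for prompt in sorted(prompts):
        shared = tuple(prompt[:prefix_len])
        groups[shared].append(prompt[prefix_len:])
    # Each group is [X, [Y1, Y2, ...]]: the shared prefix followed by the
    # non-shared contexts of every request that shares it.
    return [(list(shared), suffixes) for shared, suffixes in groups.items()]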

See below for the performance improvement

  • model: llama-3.1-8b
  • GPU: single A100
  • setting:
    • no CUDA graph and no multi-step decoding
    • for the vllm baseline, chunked-prefill (at most 2048 tokens per batch) and prefix-caching are both enabled, and ~~cascade inference v1 is enabled by default after PR #11635~~ (found that it requires VLLM_USE_V1, so that experiment is added too)
    • 6400 requests; every 400 of them share the same common prefix
    • each request has 2200 tokens: the common prefix is 2000 tokens and the non-shared context is 200 tokens
  • result:
| setting | throughput (requests/s) |
| --- | --- |
| vllm + chunked-prefill + prefix-caching, after v0.6.6.post1 (commit: 5340a30) | 6.62 |
| vllm + chunked-prefill + prefix-caching + Python's sorted(), after v0.6.6.post1 | 13.17 |
| vllm + chunked-prefill + prefix-caching + Python's sorted(), after v0.6.6.post1, VLLM_USE_V1=1 | 10.78 |
| our implementation, based on v0.6.4 | 18.01 |
  • After varying the sharing degree and the length of the shared prefix (the following tests are based on vLLM v0.6.4 and SGLang v0.4.1):

[Images: throughput under different sharing degrees and shared-prefix lengths]

Proposed Change.

High level

  1. A preprocessing step that builds the "prefix-sharing groups": BatchLLM gathers the requests with the same prefix together, identifies the shared prefixes of different requests explicitly, and enlarges the shared prefix as much as possible. The preprocessing code lives in llm.py.
  2. A new manager for the requests of shared prefixes / non-shared contexts, e.g., releasing all blocks of a shared prefix once all requests in its prefix-sharing group have finished inference.
  3. A new backend based on FlashAttnBackend, following the reviewer's suggestion here. It collects the metadata of the different prefix-sharing groups and performs the attention calculation; currently we use the Triton kernels we have implemented. (A reference sketch of the merge step is shown after this list.)
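For reference, the merge step can be illustrated in plain PyTorch: given the attention output and log-sum-exp (LSE) computed over the shared-prefix KV and over the non-shared KV, the two partial results are combined per query token. This is only a numerical sketch of what such a merge kernel has to compute, not the Triton kernel itself; the function name and tensor shapes are assumptions:

import torch

def merge_attention_states(o_prefix: torch.Tensor, lse_prefix: torch.Tensor,
                           o_suffix: torch.Tensor, lse_suffix: torch.Tensor) -> torch.Tensor:
    # o_*:   [num_tokens, num_heads, head_dim] partial attention outputs
    # lse_*: [num_tokens, num_heads] log-sum-exp of the corresponding attention logits
    # Combine the two softmax normalizations in a numerically stable way.
    lse = torch.logaddexp(lse_prefix, lse_suffix)
    w_prefix = torch.exp(lse_prefix - lse).unsqueeze(-1)
    w_suffix = torch.exp(lse_suffix - lse).unsqueeze(-1)
    return w_prefix * o_prefix + w_suffix * o_suffix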

Feedback Period.

No response

CC List.

@WoosukKwon @comaniac @pavanimajety

Any Other Things.

  • The script used in the Motivation section:

The command is: python vllm_baseline.py --model_path /workspace/llama3_8b --request_num 6400 --context_len 2000 --prompt_len 200 --generate_len 100 --sharing_degree 16 --rand --use_prefix_caching --chunk

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import time
import datetime
from pandas import read_table
import math
import argparse

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import random
random.seed(123)
torch.manual_seed(123)

def parse_args():
    parser = argparse.ArgumentParser(description='vLLM performance test')
    parser.add_argument('--use_cuda_graph', action="store_true",
                        help='use_cuda_graph')  # store_true: type=bool would treat any non-empty string as True
    parser.add_argument('--model_path', type=str, default='/workspace/llama-3.1-8b',
                        help='model_path')
    parser.add_argument('--request_num', type=int, default=6400,
                        help='request_num')
    parser.add_argument('--context_len', type=int, default=2000,
                        help='context_len')
    parser.add_argument('--prompt_len', type=int, default=200,
                        help='prompt_len')
    parser.add_argument('--generate_len', type=int, default=100,
                        help='generate_len')
    parser.add_argument('--use_prefix_caching', action="store_true",
                        help='use_prefix_caching')
    parser.add_argument('--chunk', action="store_true",
                        help='chunk')
    parser.add_argument('--rand', action="store_true",
                        help='whether the input is randomly shuffled')
    parser.add_argument('--sharing_degree', type=int, default=2,
                        help='how many requests share the same prefix')
    parser.add_argument('--random_generate', action="store_true",
                        help='not ignore eos')

    return parser.parse_args()


def prepare_tokens(tokenizer, context_len, prompt_len, group_num, batch_size):
    share = []
    all_t = []
    group_idx = []
    for i in range(group_num):
        context_this_group = torch.randint(1, 20000, (context_len,))
        share.append(context_this_group.tolist())
        for j in range(batch_size):
            prompt_this_request = torch.randint(1, 20000, (prompt_len,))
            all_t.append(
                torch.concat((context_this_group[0:context_len], prompt_this_request[0:prompt_len]), 0).tolist())
            group_idx.append(i)
    return all_t, group_num, group_num * batch_size, share, group_idx


def build_pipeline(engine_path, prefix_caching, chunk=False):
    kwargs = {}
    if "70B" in engine_path or "70b" in engine_path:
        print("70B ")
        kwargs = {"quantization": "gptq",
                  "max_model_len": 8192}
        # os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"  # env vars must be strings

    if chunk:
        pipe = LLM(model=engine_path,
                   dtype="float16",
                   tensor_parallel_size=1,
                   enable_prefix_caching=prefix_caching,
                   enforce_eager=True,
                   disable_sliding_window=True,
                   enable_chunked_prefill=True,
                   max_num_batched_tokens=2048,
                   **kwargs,
                   )
    else:
        if "llama" not in engine_path:
            kwargs = {
                **kwargs,
                "max_num_batched_tokens": 32768,
            }
        pipe = LLM(model=engine_path,
                   dtype="float16",
                   tensor_parallel_size=1,
                   enable_prefix_caching=prefix_caching,
                   enforce_eager=True,
                   disable_sliding_window=True,
                   **kwargs,
                   )
    return pipe


def run_pipeline(input_tokens, pipe, group_id_list=None, max_tokens=100, random_generate=False):
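    # Note: group_id_list is not used by this vLLM baseline; request grouping is only exploited in the BatchLLM implementation.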
    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=max_tokens,
                                     ignore_eos=(not random_generate)
                                     )
    t1 = time.time()
    output = pipe.generate(prompt_token_ids=input_tokens,
                           sampling_params=sampling_params
                           )
    t2 = time.time()
    return output, t2 - t1


def prepare_baseline_caption_token(tokenizer, file_path):
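    # Note: this helper is not called in the benchmark loop below.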
    # Read the caption file
    f = read_table(file_path)
    f_shape = f.shape
    tokens = []
    for i in range(f_shape[0]):
        token1 = tokenizer.encode(f.iloc[i, 7], add_special_tokens=False, return_tensors='pt')
        token2 = tokenizer.encode(f.iloc[i, 8], add_special_tokens=False, return_tensors='pt')
        tokens.append(torch.concat((token1[0, :], token2[0, :]), 0).tolist())
    return tokens, f_shape[0], f_shape[0]


args = parse_args()
batch_size_list = [int(args.sharing_degree)]
engine_path = args.model_path
request_num = args.request_num
context_len = args.context_len
prompt_len = args.prompt_len
generate_len = args.generate_len
use_cuda_graph = args.use_cuda_graph
prefix_caching = args.use_prefix_caching
tokenizer = AutoTokenizer.from_pretrained(engine_path)

file_name = f'baseline_066_random_{args.rand}_sd_{args.sharing_degree}_baseline_{request_num}_{context_len}_{prompt_len}_{generate_len}_chunk_{args.chunk}_{args.model_path}_prefix_caching_{str(prefix_caching)}.txt'.replace("/","_")
f = open(file_name, 'w')
print(f'////////////////////////////', file=f)
print(f' {datetime.datetime.now()} vllm performance test', file=f)

pipe = build_pipeline(engine_path, prefix_caching, args.chunk)
print(f'engine_path:{engine_path},prefix_caching:{prefix_caching}', file=f)
print(f'group_num\tprompt_num\ttime\tthroughput', file=f)


# benchmark test
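# batch_size here equals the sharing degree, i.e. the number of requests that share one common prefix.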
for batch_size in batch_size_list:
    group_num = request_num // batch_size
    input_tokens, group_num, prompt_num, share_tokens, group_idx = prepare_tokens(tokenizer, context_len, prompt_len,
                                                                                  group_num, batch_size)
    if args.rand:
        random.shuffle(input_tokens)
    
    gen_share_time = 0

    final_output, gen_time = run_pipeline(input_tokens, pipe, group_idx, generate_len, args.random_generate)


    print(f'{group_num:9}\t{prompt_num:10}\t{gen_time:8.2f}\t{prompt_num / (gen_time):10.2f}')
    print(f'{group_num:9}\t{prompt_num:10}\t{gen_time:8.2f}\t{prompt_num / (gen_time):10.2f}', file=f)


Before submitting a new issue...

  • [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
