[ V0 ][ sample ] improve sampling performance when using guided decoding

Open cjsdurj opened this issue 9 months ago • 13 comments

Introduction

This PR introduces a new sampling algorithm (SampleV2) to improve guided decoding performance.

It currently supports the xgrammar backend and the Qwen and Llama models.

Refer to the interface 'vllm.model_executor.models.interfaces.SupportsSampleV2'.

This PR increased EBNF guided decoding throughput by more than 1000% in my tests.

Background

  1. Guided decoding backends like xgrammar and outlines mainly optimize the speed of JSON guided decoding; when another GBNF/EBNF grammar is used, token throughput can be very slow.

In my test case, the grammar is:

root        ::= en-char+ ([ \t\n] en-char+)*
en-char     ::= letter | digit | punctuation
letter      ::= [a-zA-Z]
digit       ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]

When vLLM serves a Qwen 72B AWQ 4-bit model on 4x L40S GPUs, the decode speed is 50 tokens/s, but with guided_grammar applied it drops to 2.5 tokens/s.

I optimized the grammar as below, which brings it up to 12 tokens/s:

root        ::= (en-char+ [ \t\n])*  en-char+
en-char     ::= letter | digit | punctuation
letter      ::= [a-zA-Z]
digit       ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
  2. The bottleneck of guided decoding is matcher.FillNextTokenBitmask (see https://github.com/mlc-ai/xgrammar/issues/235). Previously, FillNextTokenBitmask was called at every decode step, so this PR aims to reduce the number of calls to it, as sketched below.
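To make the difference concrete, here is a rough sketch of the two strategies. This is not the PR's code: the matcher method names (accept_token, fill_next_token_bitmask) follow xgrammar's GrammarMatcher API as I understand it, and sample / apply_bitmask are hypothetical helpers supplied by the caller.

# Baseline guided decoding: the token bitmask is computed on every decode step.
def guided_step_baseline(matcher, logits, bitmask, sample, apply_bitmask):
    matcher.fill_next_token_bitmask(bitmask)   # expensive for complex grammars
    masked = apply_bitmask(logits, bitmask)    # mask out grammar-invalid tokens
    token = sample(masked)
    matcher.accept_token(token)                # advance the grammar FSM
    return token

# This PR's idea: sample unconstrained first, and only fall back to the bitmask
# when the grammar FSM rejects the sampled token.
def guided_step_accept_first(matcher, logits, bitmask, sample, apply_bitmask):
    token = sample(logits)                     # unconstrained sample
    if matcher.accept_token(token):            # common case: no bitmask fill needed
        return token
    matcher.fill_next_token_bitmask(bitmask)   # rare slow path
    masked = apply_bitmask(logits, bitmask)
    token = sample(masked)
    matcher.accept_token(token)
    return token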

How it works

    def samplev2(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # first round: sample without applying the grammar bitmask
        next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)

        # check whether the sampled tokens are accepted by the grammars
        tks = torch.tensor([o.samples[0].output_token for o in next_tokens.outputs])
        accepted = accept_grammar(tks, sampling_metadata)
        need_resample = torch.logical_not(accepted)
        if accepted.all():
            # fast path: FillNextTokenBitmask was never called this step
            return next_tokens

        # slow path: apply the grammar bitmask (logits processors) only to the
        # rows that need resampling, then sample again
        logits = _apply_logits_processors(logits, sampling_metadata, need_resample, False)
        new_next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)

        # keep the first-round token where it was accepted, otherwise take the
        # resampled one
        for i, replace in enumerate(need_resample.tolist()):
            if replace:
                next_tokens.outputs[i] = new_next_tokens.outputs[i]

        tks = torch.tensor([o.samples[0].output_token for o in next_tokens.outputs])
        # the matcher only needs to accept the new token for rows whose first
        # sample was rejected
        accepted = accept_grammar(tks, sampling_metadata, need_resample)
        assert accepted.all()
        return next_tokens

The new sampling method (samplev2) consists of the following steps:

a. Compute logits from hidden_states.
b. Sample the batched output without the grammar guide.
c. Let the FSM try to accept the output tokens; if all are accepted, return directly (FillNextTokenBitmask is never called). A sketch of this check follows below.
d. If the FSM cannot accept a token, apply the grammar guide to the logits and resample.
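A minimal sketch of what step c amounts to, assuming each constrained request carries an xgrammar GrammarMatcher. The accept_grammar name matches the code above, but the plain list-of-matchers argument here is illustrative and not the PR's actual sampling_metadata plumbing.

import torch

def accept_grammar(tokens: torch.Tensor, matchers: list) -> torch.Tensor:
    # Ask each request's grammar matcher whether it accepts the sampled token.
    # matchers[i] is None for unconstrained requests; when accept_token accepts,
    # the matcher's FSM advances, so accepted rows need no further work.
    accepted = torch.ones(len(matchers), dtype=torch.bool)
    for i, matcher in enumerate(matchers):
        if matcher is None:
            continue  # no grammar attached to this request
        accepted[i] = bool(matcher.accept_token(tokens[i].item()))
    return accepted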

Test

# use a raw string so the grammar's \t, \n and \\ escapes reach the backend verbatim
ebnfstr = r'''
root        ::= en-char+ ([ \t\n] en-char+)*
en-char     ::= letter | digit | punctuation
letter      ::= [a-zA-Z]
digit       ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
'''



payload = {
        "messages": [
            {
                "content": "tell a story about Spring ,at least 1024 words",
                "role": "user"
            }
        ],
        "max_tokens": 1024,
        "model": "llama2",
        "stream": False,
        "guided_grammar":  ebnfstr
    }
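For reference, a payload like the one above can be sent to a vLLM OpenAI-compatible server with a plain HTTP POST; the host, port, and model name are just the ones assumed for this local test setup.

import requests

# assumes `vllm serve <model>` is already running locally on port 8000
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json=payload,
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])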

Single-client test result (nearly no overhead when guided_grammar is applied):

| model | not guided (tokens/s) | before (tokens/s) | this PR (tokens/s) |
| --- | --- | --- | --- |
| qwen2.5 1.5b fp16, 1x L40S | 140 | 2 | 136 |
| qwen2.5 72b awq4, 4x L40S | 50 | 2 | 48 |

Currently this PR only supports the xgrammar backend and models like Qwen2 and Llama. More models can be supported by extending the SupportsSampleV2 class.
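For orientation, the interface presumably looks roughly like the sketch below, inferred from the samplev2 signature above rather than copied from the PR branch; the vllm import paths are the usual V0 locations and may differ in the actual code.

from typing import Optional, Protocol, runtime_checkable

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata


@runtime_checkable
class SupportsSampleV2(Protocol):
    # models opting into accept-first sampling expose this method, which the
    # model runner would call in place of the usual compute_logits + sample pair
    def samplev2(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        ...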

cjsdurj avatar Mar 17 '25 14:03 cjsdurj

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Mar 17 '25 14:03 github-actions[bot]

Sorry, can you merge from latest main to fix the pre-commit failures?

DarkLight1337 avatar Mar 17 '25 14:03 DarkLight1337

Sorry, can you merge from latest main to fix the pre-commit failures?

OK, I have fixed the pre-commit failures.

cjsdurj avatar Mar 21 '25 09:03 cjsdurj

Sorry, can you merge from latest main to fix the pre-commit failures?

OK, I have fixed the pre-commit failures.

All commit messages also need Signed-off-by headers to make the DCO check pass.

About the change -- the results are impressive, but I'm a bit concerned about what we'd do with this for V1. As written, this will only work with the feature for V0, but we're trying to focus our enhancements on V1 as much as possible.

Have you compared this to V1? That would be good as another column in your comparison. In other words, does V0 + this enhancement beat V1 with structured output in use? Or do the other enhancements in V1 already make V1 faster without this optimization in place?

russellb avatar Mar 21 '25 13:03 russellb

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @cjsdurj.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 21 '25 16:03 mergify[bot]

@cjsdurj did you mean to close this?

russellb avatar Mar 21 '25 20:03 russellb

@cjsdurj did you mean to close this? Sorry, I made a mistake last night when addressing the DCO check.

  • This feature can also work on V1, and I am working on it.

  • I have tested the V1 engine; its GBNF decode performance is as slow as V0's.

Because it changes too many .py files, this currently only works on V0 with the Qwen/Llama models, as a preview feature.

cjsdurj avatar Mar 22 '25 04:03 cjsdurj

@cjsdurj did you mean to close this? Sorry, I made a mistake last night when addressing the DCO check.

  • This feature can also work on V1, and I am working on it.
  • I have tested the V1 engine; its GBNF decode performance is as slow as V0's.

Because it changes too many .py files, this currently only works on V0 with the Qwen/Llama models, as a preview feature.

If this works OK, I will implement this feature in the V1 engine for more models in subsequent submissions.

cjsdurj avatar Mar 22 '25 04:03 cjsdurj

@cjsdurj did you mean to close this? Sorry, I made a mistake last night when addressing the DCO check.

  • This feature can also work on V1, and I am working on it.
  • I have tested the V1 engine; its GBNF decode performance is as slow as V0's.

Because it changes too many .py files, this currently only works on V0 with the Qwen/Llama models, as a preview feature.

If this works OK, I will implement this feature in the V1 engine for more models in subsequent submissions.

It will not work with V1 by design right now. In V1, advancing the grammar's FSM and applying the bitmask are in separate processes.

I'd like to see more performance numbers, in particular with large batches of requests and not just a single request. I'll do some benchmarking at some point if you don't have the hardware for it (I'll want to see some H100 results).

russellb avatar Mar 22 '25 14:03 russellb

https://github.com/vllm-project/vllm/pull/17084 removed sampler from model, this PR needs rebase.

~Let me see if I can help~ https://github.com/cjsdurj/vllm/pull/1

lk-chen avatar Apr 29 '25 19:04 lk-chen

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @cjsdurj.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 29 '25 21:04 mergify[bot]

Hi @cjsdurj, may I ask how you produced the before throughput of 2 tk/s and the after throughput of 136 tk/s?

I'm using https://github.com/lk-chen/vllm/pull/2 on L40S, forcing vLLM v0, model=Qwen/Qwen2.5-1.5B-Instruct, async mode, and got

| num_prompts | output trpt. before this PR (tk/s) | output trpt. after this PR (tk/s) |
| --- | --- | --- |
| 1 | 12.94 | 12.95 |
| 100 | 69.41 | 79.33 |

lk-chen avatar Apr 30 '25 18:04 lk-chen

Since this only applies to V0 as written, I'm going to close this out for now. V0 is deprecated and we're not adding major changes at this point. I think significant changes are necessary to come up with something that works with V1. Please feel free to reopen this or open a new PR if you come up with an approach that works in V1. Thank you!

russellb avatar May 05 '25 22:05 russellb