[V0][Sampler] Improve sampling performance when using guided decoding
Introduction
This PR introduces a new sampling algorithm (SampleV2) to improve guided decoding performance.
It currently supports the xgrammar backend and the Qwen and Llama models.
Refer to the interface vllm.model_executor.models.interfaces.SupportsSampleV2.
This PR increases EBNF guided-decoding throughput by more than 1000%.
Background
- Guided decoding backends such as xgrammar and Outlines mainly optimize the speed of JSON guided decoding; when other GBNF/EBNF grammars are used, the token throughput can be very slow.
In my test case, the grammar is below:
root ::= en-char+ ([ \t\n] en-char+)*
en-char ::= letter | digit | punctuation
letter ::= [a-zA-Z]
digit ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
When vLLM serves a Qwen 72B AWQ 4-bit model on 4x L40S GPUs, the decode speed is 50 tokens/s, but with guided_grammar applied it drops to 2.5 tokens/s.
I optimized the grammar as below, which brought it up to 12 tokens/s:
root ::= (en-char+ [ \t\n])* en-char+
en-char ::= letter | digit | punctuation
letter ::= [a-zA-Z]
digit ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
- The bottleneck of guided decoding is matcher.FillNextTokenBitmask (see https://github.com/mlc-ai/xgrammar/issues/235). Previously, FillNextTokenBitmask was called at every decode step, so this PR aims to reduce the number of calls to matcher.FillNextTokenBitmask.
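For context, here is a minimal sketch of the per-step pattern this PR tries to avoid, written against the xgrammar Python API (fill_next_token_bitmask is the Python-level counterpart of the FillNextTokenBitmask call above; the tokenizer, the toy grammar, and the random logits are only placeholders for illustration, not part of this PR):

    import torch
    import xgrammar as xgr
    from transformers import AutoTokenizer

    # Illustrative setup: any HF tokenizer works; the toy grammar stands in
    # for the EBNF grammar shown above.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    compiler = xgr.GrammarCompiler(tokenizer_info)
    matcher = xgr.GrammarMatcher(compiler.compile_grammar("root ::= [a-zA-Z]+"))

    bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
    for _ in range(16):
        # Stand-in for the model's decode step.
        logits = torch.randn(1, tokenizer_info.vocab_size)
        # Baseline behaviour: the bitmask is recomputed on every decode step,
        # which is the dominant cost for complex GBNF grammars.
        matcher.fill_next_token_bitmask(bitmask)
        xgr.apply_token_bitmask_inplace(logits, bitmask)
        token_id = int(torch.argmax(logits, dim=-1))
        if not matcher.accept_token(token_id):
            break

SampleV2 instead tries accept_token on the unconstrained sample first and only falls back to the bitmask path when that fails.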
How it works
def samplev2(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    # Sample the whole batch from the raw logits, without the grammar bitmask.
    next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)
    # Check whether the sampled tokens are accepted by the grammar FSMs.
    tks = torch.tensor(
        [o.samples[0].output_token for o in next_tokens.outputs])
    accepted = accept_grammar(tks, sampling_metadata)
    need_resample = torch.logical_not(accepted)
    if accepted.all():
        # Fast path: every token is valid and FillNextTokenBitmask was never called.
        return next_tokens
    # Resample the rejected sequences: first apply the grammar bitmask to the
    # logits (the logits processors run only where need_resample is set),
    # then sample again and keep the new tokens for those rows.
    logits = _apply_logits_processors(logits, sampling_metadata,
                                      need_resample, False)
    new_next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)
    for i, replace in enumerate(need_resample.tolist()):
        if replace:
            next_tokens.outputs[i] = new_next_tokens.outputs[i]
    tks = torch.tensor(
        [o.samples[0].output_token for o in next_tokens.outputs])
    # The matchers are advanced only for sequences that were rejected in the
    # first round; the accepted ones were already advanced above.
    accepted = accept_grammar(tks, sampling_metadata, need_resample)
    assert accepted.all()
    return next_tokens
The new sampling method (samplev2) consists of the following steps:
a. Compute logits from the hidden states.
b. Sample the batched output without the grammar guide.
c. Let the FSM try to accept the output tokens; if all are accepted, return directly (FillNextTokenBitmask is never called).
d. If the FSM cannot accept some tokens, apply the grammar bitmask to the logits and resample them.
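A simplified, hypothetical sketch of the accept-first check in step (c); the real accept_grammar in this PR takes sampling_metadata and may differ, but the idea is to advance each sequence's matcher only when the unconstrained sample is already valid:

    import torch

    def accept_grammar_sketch(token_ids: torch.Tensor,
                              matchers: list) -> torch.Tensor:
        # matchers holds one xgrammar GrammarMatcher per sequence (or None if
        # the sequence has no grammar). Returns a bool mask of sequences whose
        # sampled token was accepted without computing any token bitmask.
        accepted = torch.zeros(len(matchers), dtype=torch.bool)
        for i, (matcher, token_id) in enumerate(zip(matchers,
                                                    token_ids.tolist())):
            if matcher is None:
                accepted[i] = True
            else:
                # accept_token advances the FSM only if the token is valid,
                # so FillNextTokenBitmask is skipped on the happy path.
                accepted[i] = matcher.accept_token(token_id)
        return accepted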
Test
ebnfstr = '''
root ::= en-char+ ([ \t\n] en-char+)*
en-char ::= letter | digit | punctuation
letter ::= [a-zA-Z]
digit ::= [0-9]
punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
'''
payload = {
"messages": [
{
"content": "tell a story about Spring ,at least 1024 words",
"role": "user"
}
],
"max_tokens": 1024,
"model": "llama2",
"stream": False,
"guided_grammar": ebnfstr
}
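For completeness, a minimal sketch of how this payload might be sent to a vLLM OpenAI-compatible server; the localhost URL and port are assumptions, and guided_grammar is passed through as an extra body field:

    import requests

    # Assumes a server is already running locally, e.g. started with
    # `vllm serve <model> --port 8000`.
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json=payload,
        timeout=600,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])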
Single-client test result (nearly no overhead when guided_grammar is applied):
| model | not guided (tokens/s) | before this PR (tokens/s) | with this PR (tokens/s) |
|---|---|---|---|
| qwen2.5 1.5b fp16 1*L40s | 140 | 2 | 136 |
| qwen2.5 72b awq4 4*L40s | 50 | 2 | 48 |
Currently this PR only supports the xgrammar backend and models such as Qwen2 and Llama; more models can be supported by extending the SupportsSampleV2 interface.
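As a rough illustration only (the actual interface is defined in this PR and may differ), extending support to another model could look like implementing a protocol that mirrors the samplev2 signature above; the class name and details here are assumptions:

    from typing import Optional, Protocol, runtime_checkable

    import torch

    @runtime_checkable
    class SupportsSampleV2Sketch(Protocol):
        # Hypothetical stand-in for
        # vllm.model_executor.models.interfaces.SupportsSampleV2.
        # A model opts in by implementing samplev2 with this shape.
        def samplev2(
            self,
            logits: torch.Tensor,
            sampling_metadata: "SamplingMetadata",
        ) -> Optional["SamplerOutput"]:
            ...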
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.
🚀
Sorry, can you merge from latest main to fix the pre-commit failures?
OK, I have fixed the pre-commit failures.
All commit messages also need Signed-off-by headers to make the DCO check pass.
About the change -- the results are impressive, but I'm a bit concerned about what we'd do with this for V1. As written, this feature will only work with V0, but we're trying to focus our enhancements on V1 as much as possible.
Have you compared this to V1? That would be good as another column in your comparison. In other words, does V0 + this enhancement beat V1 with structured output in use? Or do the other enhancements in V1 already make V1 faster without this optimization in place?
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @cjsdurj.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@cjsdurj did you mean to close this?
Sorry, I made a mistake last night while addressing the DCO check.
- This feature can also work on V1, and I am working on it.
- I have tested the V1 engine; its GBNF decode performance is as slow as V0's.
Because it changes too many Python files, this PR currently only works on V0 and the Qwen/Llama models as a preview feature. If this works out, I will implement the feature in the V1 engine for more models in subsequent submissions.
It will not work with V1 by design right now. In V1, advancing the grammar's FSM and applying the bitmask are in separate processes.
I'd like to see more performance numbers, in particular with large batches of requests and not just a single request. I'll do some benchmarking at some point if you don't have the hardware for it (I'll want to see some H100 results).
https://github.com/vllm-project/vllm/pull/17084 removed the sampler from the model, so this PR needs a rebase.
~Let me see if I can help~ https://github.com/cjsdurj/vllm/pull/1
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @cjsdurj.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Hi @cjsdurj, may I ask how you produced the before throughput of 2 tk/s and the after throughput of 136 tk/s?
I'm using https://github.com/lk-chen/vllm/pull/2 on L40S, forcing vLLM v0, model=Qwen/Qwen2.5-1.5B-Instruct, async mode, and got
| num_prompts | output throughput before this PR (tk/s) | output throughput after this PR (tk/s) |
|---|---|---|
| 1 | 12.94 | 12.95 |
| 100 | 69.41 | 79.33 |
Since this only applies to V0 as written, I'm going to close this out for now. V0 is deprecated and we're not adding major changes at this point. I think significant changes are necessary to come up with something that works with V1. Please feel free to reopen this or open a new PR if you come up with an approach that works in V1. Thank you!