[Frontend][Core] Add Guidance backend for guided decoding
This pull request extends vLLM's guided decoding capabilities.
- Add a guidance backend
The guidance backend supports regex, choice, JSON, and grammar constraints (an illustrative grammar sketch is included after the usage examples below).
Related: https://github.com/vllm-project/vllm/issues/5245
Usage
- JSON Generation
from pydantic import BaseModel, ConfigDict

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import CustomChatCompletionMessageParam
from vllm.sampling_params import GuidedDecodingParams

model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
llm = LLM(model=model)

class UserProfile(BaseModel):
    name: str
    age: int
    email: str
    model_config = ConfigDict(extra="forbid")

sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        json=UserProfile,
        backend="guidance",
    ),
)
outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="Tell me something about yourself (name, age, email) in JSON format.\n",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
- Choices Generation
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        choice=["3", "4", "5", "6"],
        backend="guidance",
    ),
)
outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a 5-year-old helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="How old are you?",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
- Regex Generation via OpenAI Client
model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="NOKEY",
)
completion = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "You are a 5 years-old helpful assistant.",
},
{
"role": "user",
"content": """How old are you?""",
},
],
extra_body={"guided_regex": "\\d+", "guided_decoding_backend": "guidance"}
)
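- Grammar Generation (illustrative sketch)
The grammar case is not covered by the examples above, so here is a minimal sketch that reuses the `llm`, `SamplingParams`, and `GuidedDecodingParams` imports from the JSON example. The Lark-style grammar string and the `llm.generate` call are illustrative assumptions, not code taken from this PR.

```python
# Sketch only: the Lark-style grammar below is an assumed example of what the
# guidance backend accepts; adjust it to the grammar syntax your backend
# version documents.
arithmetic_grammar = r"""
start: NUMBER " + " NUMBER
NUMBER: /[0-9]+/
"""

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=32,
    guided_decoding=GuidedDecodingParams(
        grammar=arithmetic_grammar,
        backend="guidance",
    ),
)
outputs = llm.generate(
    prompts=["Write the sum of two numbers: "],
    sampling_params=sampling_params,
)
```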
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add the ready label to the PR
- Enable auto-merge.
🚀
Thanks @njhill for your quick review. Really appreciate it.
- Presumably the parallelization speedup is due to the fact that the PyTorch ops involved release the GIL?
That's one reason; another is that the parser (llguidance) used by guidance is implemented in Rust and automatically releases the GIL when called, so running guidance in parallel is more efficient (a rough sketch of the threadpool approach is included below).
- Were your outlines measurements also using the threadpool?
Yes, the experiments were done using the threadpool.
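To make the threadpool approach above concrete, here is a minimal, hypothetical sketch of applying per-sequence guided-decoding logits processors across a thread pool. The `processors` and `token_ids_per_seq` names are placeholders rather than code from this PR, and the speedup depends on each processor releasing the GIL (as the Rust-based llguidance parser does) or spending its time in PyTorch ops that do.

```python
# Hypothetical sketch, not the code from this PR: run one guided-decoding
# logits processor per sequence on a thread pool. This only helps if the
# processors release the GIL while doing their work.
from concurrent.futures import ThreadPoolExecutor

import torch


def apply_processors_in_parallel(processors, token_ids_per_seq, logits):
    """Apply processors[i] to logits row i for the token ids of sequence i."""

    def run(i):
        return processors[i](token_ids_per_seq[i], logits[i])

    with ThreadPoolExecutor(max_workers=8) as pool:
        rows = list(pool.map(run, range(len(processors))))
    return torch.stack(rows)
```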
- It would be good to also try with the latest outlines 0.1.x if possible which is apparently much faster than < 0.1. We would want to upgrade to that too in any case.
I haven't tested outlines 0.1.x yet, just the version currently used in vLLM. However, I am not focusing too much on benchmarks for this PR. The focus is to make guidance available as another guided decoding backend so vLLM's community can choose what's best for them. :)
I also figured out that lm-format-enforcer is not thread-safe. It failed some tests when the number of threads is larger than 1. @njhill any suggestions for this?
Decided to roll back to the single-threaded version so as not to break lm-format-enforcer. The PR now comes with minimal changes that add llguidance as a new logits processor. Hope the current code is good for merging :) @njhill @mgoin
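For readers unfamiliar with the logits-processor hook mentioned above: in the V0 engine a logits processor is just a callable applied per sequence before sampling, taking the generated token ids and the logits row and returning adjusted logits. The sketch below shows that shape with a fixed allowed-token set standing in for the grammar parser; `make_guided_processor` is a hypothetical helper, not the code added in this PR.

```python
# Hypothetical sketch of the logits-processor shape used for guided decoding.
# A fixed `allowed_token_ids` set stands in for the grammar parser (llguidance
# in this PR), which would decide the allowed set at every step.
from typing import List, Set

import torch


def make_guided_processor(allowed_token_ids: Set[int]):
    """Return a processor that masks every token id not in allowed_token_ids."""

    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(logits, float("-inf"))
        mask[list(allowed_token_ids)] = 0.0
        return logits + mask

    return processor
```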
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @JC1DA.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Resolved conflict with newly merged xgrammar
@njhill @mgoin
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @JC1DA.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
We have just released a large JSON Schema benchmark and a paper. Of particular interest might be the isolated mask-generation benchmarks comparing LLGuidance, Outlines, XGrammar, and llama.cpp grammars.
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @JC1DA.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Hello! I wanted to try this out, so I re-applied the changes on top of main, adjusting as necessary to get it to fit the current state of things.
https://github.com/vllm-project/vllm/compare/main...russellb:vllm:llguidance-v0-integration?expand=1
It's failing on _initialize() in the logits processor right now. Perhaps someone could take a look with me? #forum-structured-output on the vLLM Slack would be a good place to find me outside of GitHub if you'd like to chat.
$ pytest -sv tests/model_executor/test_guided_processors.py::test_guided_logits_processor_black_box[True-guidance]
...
tests/model_executor/test_guided_processors.py:93:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
vllm/model_executor/guided_decoding/guidance_logits_processors.py:116: in __call__
self._initialize()
vllm/model_executor/guided_decoding/guidance_logits_processors.py:86: in _initialize
TransformersTokenizer( \
vllm/model_executor/guided_decoding/guidance_utils.py:183: in __init__
byte_tokens = self._byte_tokens(transformers_tokenizer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <vllm.model_executor.guided_decoding.guidance_utils.TransformersTokenizer object at 0x7f52a7c9a840>
transformers_tokenizer = LlamaTokenizerFast(name_or_path='HuggingFaceH4/zephyr-7b-beta', vocab_size=32000, model_max_length=1000000000000000019...ecial=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)
    def _byte_tokens(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
        if hasattr(transformers_tokenizer, "byte_decoder"):
            try:
                self._check_byte_decoder(transformers_tokenizer.byte_decoder,
                                         transformers_tokenizer)
            except ByteDecoderError as e:
                error_log = f"Tokenizer has a byte_decoder, \
                    but it can't be used to construct byte_tokens: {e}"
                logging.warning(error_log)
                pass
            else:
                return self._byte_tokens_from_byte_decoder(
                    transformers_tokenizer.byte_decoder,
                    transformers_tokenizer)
        if hasattr(transformers_tokenizer, "sp_model"):
            return self._byte_tokens_from_sp_model(transformers_tokenizer)
        try:
            return self._byte_tokens_by_encoding_token_strings(
                transformers_tokenizer)
        except ValueError as e:
            error_log = f"Could not build byte tokens from the \
                tokenizer by encoding token strings: {e}"
            logging.warning(error_log)
            pass
        fallback_byte_decoder = self._fallback_byte_decoder()
        try:
            self._check_byte_decoder(fallback_byte_decoder,
                                     transformers_tokenizer)
        except ByteDecoderError as e:
            # Should be the only exception that is raised in _byte_tokens
>           raise ByteTokensError(
                "Could not build byte tokens from the tokenizer, \
                and falling back to a standard gpt2 byte_decoder failed"
            ) from e
E           vllm.model_executor.guided_decoding.guidance_utils.ByteTokensError: Could not build byte tokens from the tokenizer, and falling back to a standard gpt2 byte_decoder failed
vllm/model_executor/guided_decoding/guidance_utils.py:288: ByteTokensError
----------------------------------------------------------------- Captured log call ------------------------------------------------------------------
WARNING root:guidance_utils.py:279 Could not build byte tokens from the tokenizer by encoding token strings: Round-trip encoding of tokens [<0x00>] failed! Got [1, 28705, 3]
============================================================== short test summary info ===============================================================
FAILED tests/model_executor/test_guided_processors.py::test_guided_logits_processor_black_box[True-guidance] - vllm.model_executor.guided_decoding.guidance_utils.ByteTokensError: Could not build byte tokens from the tokenizer, and falli...
Hey @russellb! We've been tracking the discussion on https://github.com/vllm-project/vllm/pull/12388. Our plan is to re-do this PR once that gets merged. llguidance exposes a similar API to xgrammar, so it'll be quite a bit easier to just drop our code in at that point.
Happy to get started on it whenever you recommend. Thanks for the pointer on the slack, we'll join and chat there too :)
@lochuynh1412 @mmoskal
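On the "similar API" point: grammar backends such as xgrammar and llguidance both follow a compute-a-token-mask-then-apply-it pattern, which is what makes swapping backends behind a common interface straightforward. The sketch below illustrates that pattern generically; `matcher` and `allowed_token_ids()` are hypothetical stand-ins, not the actual API of either library.

```python
# Generic sketch of the mask-then-sample pattern shared by grammar backends;
# `matcher.allowed_token_ids()` is a hypothetical stand-in for whatever call
# the real backend provides to report which tokens the grammar permits next.
import torch


def apply_grammar_mask(matcher, logits: torch.Tensor) -> torch.Tensor:
    allowed = matcher.allowed_token_ids()  # hypothetical backend call
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed] = 0.0  # keep allowed tokens, push the rest to -inf
    return logits + mask
```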
That sounds great. I want to get multiple backends going for the V1 engine after that PR merges.
I also might have a use case for this in the V0 engine for an existing user, which is what brought me over to this PR. I figured I could help get it updated and working so I can test whether it works for them.
Superseded by https://github.com/vllm-project/vllm/pull/14589