[V1][Core] Support for Structured Outputs
This PR introduces the first iteration of structured output support for the V1 engine.
While functional, it is not feature complete with the support in V0. We currently only support xgrammar as a backend. We do not have a fallback in place to outlines as we did in V0. Other backends will come in a follow-up.
While one of the goals in V1 is to minimize or eliminate conflicts between features, this does not yet work with speculative decoding. The features were developed in parallel and we haven’t had a chance to consider the interoperability challenges in depth. This will also be considered as possible follow-up work.
Some key points of the current design include:
- Compilation of the grammar for a structured output request is done asynchronously and will not block the scheduler or any other requests from getting scheduled in the meantime.
- We keep a cache of compiled grammars for accelerating the case where the same grammar is used repeatedly.
- Advancing the FSM and calculating the next logits bitmask is done in the scheduler, and the bitmask is then broadcast to the GPU workers along with the rest of the inputs already being sent.
- We arrange the bitmasks in a single tensor so that it can be applied to the full batch of logits in a single operation (see the sketch below).
There are several ideas on how this design might evolve. By using this as a functional starting point, we will be able to evaluate changes using benchmarks.
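To make these points concrete, here is a minimal sketch of the intended per-step flow. It is illustrative only, not the PR's actual code: the backend and matcher interface (compile_grammar, fill_next_token_bitmask) is a hypothetical stand-in loosely modeled on an xgrammar-style API.

```python
from concurrent.futures import Future, ThreadPoolExecutor

import torch

VOCAB_SIZE = 128_000  # example vocabulary size


class StructuredOutputManager:
    """Sketch: asynchronous grammar compilation plus a cache of compiled grammars."""

    def __init__(self, backend):
        self.backend = backend  # hypothetical grammar compiler (e.g. an xgrammar-style backend)
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.cache: dict[tuple[str, str], Future] = {}

    def get_grammar(self, grammar_type: str, grammar_spec: str) -> Future:
        key = (grammar_type, grammar_spec)
        if key not in self.cache:
            # Compile off the scheduler thread; other requests keep getting scheduled.
            self.cache[key] = self.executor.submit(
                self.backend.compile_grammar, grammar_type, grammar_spec
            )
        return self.cache[key]  # the scheduler only polls .done(); it never blocks here


def build_step_bitmask(matchers: list) -> torch.Tensor:
    """Fill one packed int32 row per structured-output request in the current batch.

    -1 means all 32 bits set, i.e. every token allowed, so untouched rows are no-ops.
    """
    bitmask = torch.full((len(matchers), VOCAB_SIZE // 32), -1, dtype=torch.int32)
    for row, matcher in enumerate(matchers):
        matcher.fill_next_token_bitmask(bitmask, row)  # hypothetical matcher call
    return bitmask  # broadcast to the GPU workers together with the other step inputs
```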
Benchmarks comparing the V0 implementation versus this new V1 implementation:
Co-authored-by: Russell Bryant [email protected]
Co-authored-by: Michael Goin [email protected]
Co-authored-by: Woosuk Kwon [email protected]
Signed-off-by: Aaron Pham [email protected]
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of that by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add the ready label to the PR
- Enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @aarnphm.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
I love seeing structured decoding being integrated deeply into vLLM!
I would love to see llguidance being supported though. Compared to XGrammar, it is significantly faster, has near-zero compilation time, and has much broader JSON Schema support. We've been using it in production instances.
If needed I'm happy to add additional APIs to the Python bindings (server-side integrations so far have been native) or otherwise help.
Hi, there are some issues/bugs I ran into when running json_object mode with xgrammar as the backend:
- Decouple JSON from JSON object. Issue: https://github.com/vllm-project/vllm/issues/13429 (I have a fix for it, but it's kind of hacky; I'll try to push a PR with a suitable fix.)
- Async grammar compilation leads to the cache being None, hence GG is not enabled. (Code snippet for looking into the fix.)
- FSM update and matching blocked by request status. Pull request for the fix; issue: https://github.com/vllm-project/vllm/issues/13433
Please do let me know if I have wrongly mentioned any of the above as a bug! I have a PR for point 3 here; do let me know if we can merge it! Thank you! cc: @russellb and @mgoin
Thanks for the progress! Please let me know when this PR is ready for review!
Will do. This PR should be marked as a draft, though I don't have permission to do that.
Chiming in to +1 @Ubospica with a few comments:
- Using a single bitmask of shape (batch_size, vocab_size/32) is an interesting choice. It seems like a poor decision at first glance, since not all requests have to use guided decoding and therefore we might over-allocate. But given the size of the bitmasks (0.5MB total for a batch size of 1024), it might even make sense to have a tensor of size (max_batch_size, vocab_size/32) resident on the GPU and then use masked batched calls to operate on indexed slices of this tensor (see the sketch after this list). In theory this would let us batch all of the guided-decoding calls for a batch into a single copy and a single kernel invocation, but that might be wishful thinking.
- I think it is a good decision to overlap advance_state and fill_bitmask wherever possible and keep it as lazy as possible until the data needs to be consumed. However, there are some downsides to this as discussed previously; I think we will have to try both and measure the performance.
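As a rough illustration of the pre-allocation idea in the first point, a resident packed mask with batched, indexed row updates could look something like the following. This is my own sketch; the sizes, names, and the upload_step_masks helper are assumptions, not the PR's implementation.

```python
import torch

MAX_BATCH = 1024
VOCAB = 128_000
WORDS = VOCAB // 32  # packed int32 words per request; ~16 MB for the full (1024, 128k) mask

# Pre-allocated, GPU-resident packed mask. -1 (all bits set) means "all tokens allowed",
# so rows belonging to requests without structured output are harmless no-ops.
gpu_bitmask = torch.full((MAX_BATCH, WORDS), -1, dtype=torch.int32, device="cuda")


def upload_step_masks(cpu_rows: torch.Tensor, guided_idx: torch.Tensor) -> None:
    """Copy only the rows for structured-output requests, in one batched transfer.

    cpu_rows: (len(guided_idx), WORDS) int32, filled by the matchers on the CPU side.
    guided_idx: int64 indices of those requests within the current batch.
    """
    gpu_bitmask[guided_idx] = cpu_rows.to(gpu_bitmask.device, non_blocking=True)
    # A single masking kernel over gpu_bitmask[:batch_size] then covers the whole batch.
```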
Also agree about guided + speculative decoding, I have a working local implementation of this for V0 with xgrammar and am happy to pitch in when V1 is ready for this feature.
@benchislett for 128k tokenizer, 1024 batch*128k vocab / 8 bits per byte = 16MB; not that it matters all that much...
Though that would be the bit-compressed size of the bitmask, it will be 8x this size in tensor form. So pre-allocating a full size mask would take 128MB. That's still probably ok, but I am looking into sharing a single such pre-allocated mask since there are other sampling params which require masking too and multiple 128MB allocations start to become problematic.
In fact I expect in practice the structured decoding mask to be quite sparse with either most bits on or most bits off. It would likely be more efficient to update the tensor mask from a sparse representation too (i.e. a list of allowed or disallowed token ids), otherwise I guess we would need a 128k loop for each request to convert the mask each time?
@njhill Could you explain what you mean by tensor form? Why would it be 8x larger? To my knowledge, the shape of the bitmask tensor is always (batch_size, vocab_size/32), at least in xgrammar.
@njhill ideally you want the masking to happen during your softmax computation; it's easy to operate on bitmasks there
@benchislett I just mean if you are doing regular torch tensor masking of the logits with a boolean tensor mask.
@mmoskal you mean with a custom kernel?
@njhill yes with a custom kernel; if you want to do masking separately, xgrammar ships with a custom kernel for that, but you can also do it with a bit of torch.compile - seems to work as fast as the xgrammar version https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/torch.py#L30
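For readers following along, here is an illustrative plain-PyTorch version of such a masking step. It is my own sketch, not the linked llguidance helper or xgrammar's kernel, and it assumes the common packing convention where bit i of packed word w guards token w*32 + i.

```python
import torch


@torch.compile(dynamic=True)  # optional; the eager version behaves the same way
def mask_logits_with_bitmask_(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    """logits: (batch, vocab) float; bitmask: (batch, vocab // 32) int32."""
    batch, vocab = logits.shape
    shifts = torch.arange(32, device=logits.device, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1         # (batch, words, 32), one bit per token
    allowed = bits.reshape(batch, -1)[:, :vocab].bool()  # drop padding past the real vocab size
    logits.masked_fill_(~allowed, float("-inf"))         # disallowed tokens can never be sampled
```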
Ah thanks @mmoskal, sorry my statement about looping was just completely wrong! I see it's not too bad to expand it after all.
Regarding the overall design of logit processors in V1, is there any reason why the grammar processor needs to live in the main (scheduler) process? The overall goal of V1 seems to be to minimise CPU contention with the scheduler, so putting the grammar work in the same process seems to defeat the purpose. It also requires cross-process transfer of the bitmask, which seems very inefficient.
IMO it would be much more scalable to have the grammar processor owned by the first worker, with the grammar processor abstracted away from the scheduler's POV.
@andylolu2 we were just discussing this exact thing in the slack channel. But a decision was already made to take this approach for the first pass, we could perhaps reevaluate in the next iteration. I think part of the reason is due to the fact that the scheduling can depend on the result of the grammar processor including whether it's finished in time for the next step.
@njhill I think the torch.compile kernel fuses the kernels and doesn't actually expand. But I may be wrong.
Here was the design doc that was written ahead of this iteration of the implementation:
https://docs.google.com/document/d/1H6m_Y3FLJ1FYGCmjXdZzoJv-JCDSxnKuSY2XiAj-c6c/edit?tab=t.0#heading=h.c99v8j7ypbym
I intend to convert this to a markdown file to be included in the docs when we're close to wrapping this up.
For the "why in the scheduler" question, the key factor was we thought it was going to make jump decoding easier to implement. We're not trying to implement that in this iteration, though. It may still turn out that moving it might be better, but I think getting to something that works is the most important next step because all future changes can be more data-driven based on their impact on benchmarks.
There are a number of things to consider regarding compatibility with speculative decoding. It is not clear to me how your code is handling these things, so I will list what I consider fundamental considerations and I hope that you might direct me to how/where they are handled in this implementation.
- The proposed tokens need to be checked for validity according to the matcher. For each draft token, it should either be restricted to be sampled from the guided distribution (for draft-model spec) or it should be forcibly rejected/omitted if it would not be accepted by the matcher (for ngram spec).
- The matcher state should be reset after drafting tokens so that the masking/checking of draft tokens stays consistent between iterations. This is fairly simple for ngram, but one should be careful to roll back the right number of states: if the matcher rejects a token, its state is not advanced. So when calling advance_state on a 10-token sequence for which the first 4 are valid and the rest are rejected, the subsequent rollback should only be called with N=4. This must be handled precisely so as not to affect the matcher state (see the sketch after this list).
- The scorer should apply guidance to the logits for each sample in the scoring sequence. This is so that the bonus token can be sampled from the masked distribution no matter which draft tokens are accepted/rejected by the target model. This may involve generating N+1 masks, one for each prefix sequence being scored, depending on how scoring is implemented. For each advancement made to generate these scoring masks, the matcher state should be rolled back until the end of the prefix of proposal tokens which are accepted by the target model.
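To make the rollback bookkeeping in the second point concrete, here is a small sketch built around a hypothetical matcher exposing accept_token / rollback (names taken from this discussion, not necessarily the PR's or any backend's exact API):

```python
def validate_draft_tokens(matcher, draft_tokens: list[int]) -> int:
    """Check draft tokens against the grammar, then restore the matcher state.

    Returns the number of leading draft tokens the matcher accepts. A rejected token
    does not advance the matcher, so only the accepted prefix is rolled back.
    """
    accepted = 0
    for token in draft_tokens:
        if not matcher.accept_token(token):  # hypothetical: advances state only if accepted
            break
        accepted += 1
    matcher.rollback(accepted)  # roll back exactly the tokens that advanced the state
    return accepted
```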
> The proposed tokens need to be checked for validity according to the matcher. For each draft token, it should either be restricted to be sampled from the guided distribution (for draft-model spec) or it should be forcibly rejected/omitted if it would not be accepted by the matcher (for ngram spec).

Currently, the proposed tokens are being validated with accept_token, and the matcher should be rolled back accordingly in all cases where advancing the FSM is not supported.
I don't think we have full compatibility with spec decode on this branch yet; we will look into that after this PR.
> The matcher state should be reset after drafting tokens so that the masking/checking of draft tokens stays consistent between iterations. This is fairly simple for ngram, but one should be careful to roll back the right number of states: if the matcher rejects a token its state is not advanced. So when calling advance_state on a 10-token sequence for which the first 4 are valid and rest are rejected, the subsequent rollback should only be called with N=4. This must be handled precisely to not affect the matcher state.

I think this makes sense, but the current implementation does not support this yet, afaik.
I'd like to retract my statement on Slack about speculative decoding support; we do not yet have full support for speculative decoding in this case.
I'll add a change to cleanly reject guided decode requests when spec decode is enabled. We can address that in a follow-up.
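A minimal sketch of that kind of guard, with hypothetical names (params.guided_decoding, vllm_config.speculative_config) standing in for whatever the real config objects are:

```python
def validate_structured_output_request(params, vllm_config) -> None:
    # Reject the combination explicitly until the interaction between structured
    # outputs and speculative decoding has been designed and tested.
    if params.guided_decoding is not None and vllm_config.speculative_config is not None:
        raise ValueError(
            "Structured outputs are not yet supported with speculative decoding in V1."
        )
```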
@aarnphm proposed text for the PR description:
Updated PR title: [V1][Core] Structured Output support