[RFC][V1] `LogitsProcessor` interface
Proposed abstraction for how to handle sampling parameters in relation to the persistent batch. This interface could then be used as an extension point for custom logits processors.
Key goals/ideas:
- Logits processor implementations are configured globally; per-request logits processors won't be supported
- They apply at a batch level rather than per-request to allow for / encourage vectorized application
- Each logits processor encapsulates its own state and is responsible for updating it as needed based on notification of persistent batch updates and new output tokens each step. This minimizes the number of times tensors need to be reconstructed and updated on the GPU.
To demonstrate the idea I've implemented LPs for min_tokens, logit_bias and min_p, but if we decide to go this route it should be straightforward to refactor the others similarly.
Note this is just to discuss the general approach - it could still be simplified/refined further.
import dataclasses
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch

from vllm import SamplingParams


class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def update_states(
        self,
        batch_update: Optional["BatchUpdate"] = None,
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: non-None iff there have been
                changes to the batch makeup.
        """
        raise NotImplementedError


@dataclasses.dataclass
class BatchUpdate:
    # Batch indices of any removed requests.
    removed: List[int]
    # (from, to) batch indices of any requests
    # moved within the batch.
    moved: List[Tuple[int, int]]
    # (index, params, output_tok_ids) for new
    # requests added to the batch.
    # TODO may need to include one or two other things here,
    # like prompt token ids.
    added: List[Tuple[int, SamplingParams, List[int]]]
    # The current number of requests in the batch.
    batch_size: int
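To make the lifecycle concrete, here's a rough sketch of what a min_p-style implementation could look like against this interface (illustrative only; the actual impl in this PR may differ, and the buffer sizing/copy strategy is just an assumption):

```python
class MinPLogitsProcessor(LogitsProcessor):
    """Sketch: masks tokens whose probability is below
    min_p * max_prob for each request that sets min_p > 0."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        # Per-request min_p values, kept in sync with the persistent batch.
        self.min_p_cpu = torch.zeros(max_num_reqs, dtype=torch.float32)
        self.min_p_gpu = torch.empty(max_num_reqs, dtype=torch.float32,
                                     device=device)
        self.batch_size = 0

    def update_states(self, batch_update: Optional["BatchUpdate"] = None) -> None:
        if batch_update is None:
            # No change to the batch makeup; min_p has no per-token state.
            return
        for index in batch_update.removed:
            self.min_p_cpu[index] = 0.0
        for index, params, _ in batch_update.added:
            self.min_p_cpu[index] = params.min_p
        for frm, to in batch_update.moved:
            self.min_p_cpu[to] = self.min_p_cpu[frm]
        self.batch_size = batch_update.batch_size
        # Rebuild the on-device tensor only when the batch actually changed.
        self.min_p_gpu[:self.batch_size].copy_(self.min_p_cpu[:self.batch_size])

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        min_p = self.min_p_gpu[:self.batch_size]
        if not min_p.any():
            return logits  # No request in the batch uses min_p.
        probs = logits.softmax(dim=-1)
        max_probs = probs.amax(dim=-1, keepdim=True)
        # Mask out tokens whose probability falls below min_p * max_prob.
        mask = probs < min_p.unsqueeze(1) * max_probs
        return logits.masked_fill(mask, float("-inf"))
```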
@WoosukKwon @AlpinDale @houseroad
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@njhill shouldn't BatchUpdate also include the tokens that were sampled for each current sequence? it's not needed for the min_p/min_tokens/logit_bias ones you implemented, but would be needed for anything more complicated?
So this is not necessarily the final state, but currently it's handled via the list that's passed in with the added requests. That list is updated in place with the new tokens, so the impl can hang on to it if it needs to know them and check it each time it's called. The idea is that BatchUpdate will only be present if requests have been added to or removed from the batch; otherwise this list of output ids can be checked for all of the current requests that the LP cares about. Of course this will need to be clearly documented if it remains the way it's done :)
There will definitely need to be changes to this regardless, like adding the prompt tokens, and we could possibly include a pointer to an on-device tensor of just the newly generated tokens (which will be there anyhow), since it might be faster for the LP to use that directly rather than multiple LPs copying these tokens back to the GPU.
You can see an example of this in the min_tokens impl in this PR, which uses the length of that list.
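For illustration, a min_tokens-style processor might hang on to that list roughly like this (a sketch, not the exact code in this PR; it assumes SamplingParams exposes the stop/EOS ids via something like `all_stop_token_ids`):

```python
from typing import Dict, List, Optional, Set, Tuple


class MinTokensLogitsProcessor(LogitsProcessor):
    """Sketch: suppresses stop tokens until a request has produced
    at least min_tokens output tokens."""

    def __init__(self):
        # index -> (shared output_tok_ids list, min_tokens, stop token ids)
        self.min_toks: Dict[int, Tuple[List[int], int, Set[int]]] = {}

    def update_states(self, batch_update: Optional["BatchUpdate"] = None) -> None:
        if batch_update:
            for index in batch_update.removed:
                self.min_toks.pop(index, None)
            for index, params, output_tok_ids in batch_update.added:
                if params.min_tokens > 0:
                    # Keep a reference to the shared list; it is appended to
                    # in place as new tokens are sampled each step.
                    self.min_toks[index] = (output_tok_ids, params.min_tokens,
                                            params.all_stop_token_ids)
            for frm, to in batch_update.moved:
                if (entry := self.min_toks.pop(frm, None)) is not None:
                    self.min_toks[to] = entry
        # Drop requests that have already reached min_tokens, based on
        # the current length of the shared output token list.
        self.min_toks = {
            i: e for i, e in self.min_toks.items() if len(e[0]) < e[1]}

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for index, (_, _, stop_tok_ids) in self.min_toks.items():
            logits[index, list(stop_tok_ids)] = float("-inf")
        return logits
```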
In the V1 interface, I'm curious if the following code can still be used to customize the Logits processor for a single batch?
outputs = llm.generate(
    prompts,
    SamplingParams(
        temperature=temperature,
        max_tokens=1,
        allowed_token_ids=accepted_ids,
    ),
)
And if not, is there any alternative solution?
Thank you for your time.
You can set the guided decoding backend to "guidance" and use the following grammar, assuming your tokens are 74, 837, 1010, 1011, 1012, 1013:
start: ( <[74]> | <[837]> | <[1010-1013]> )*
for details see https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
Make sure you have llguidance at least 0.7.10 (should be ready in ~20 min); 0.7.9, as currently used in vLLM, would reject this grammar in the front-end.
Of course you can also use strings in your grammar, not tokens, but then we will allow any tokenization of the given string, which may or may not be what you want.
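For reference, the offline call could look roughly like this (a sketch assuming vLLM's `GuidedDecodingParams` offline API; exact argument names may vary between versions, and the model name is just a placeholder):

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Restrict generation to tokens 74, 837 and 1010-1013 via an
# llguidance grammar instead of allowed_token_ids.
grammar = "start: ( <[74]> | <[837]> | <[1010-1013]> )*"

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          guided_decoding_backend="guidance")

outputs = llm.generate(
    ["Pick one of the allowed tokens:"],
    SamplingParams(
        temperature=0.0,
        max_tokens=1,
        guided_decoding=GuidedDecodingParams(grammar=grammar),
    ),
)
```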
If I want to force an R1-like model to stop thinking (e.g., when prompt + outputs > 8192, force it to generate </think>),
how can I get the prompt length for a single request?
Thank you for your time.
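One way this could look under the interface proposed in this PR, assuming the TODO above is resolved so that `BatchUpdate.added` also carries the prompt token ids, and reusing the `LogitsProcessor`/`BatchUpdate` definitions and typing imports from above (purely a hypothetical sketch, not something implemented here):

```python
class ForceThinkEndLogitsProcessor(LogitsProcessor):
    """Sketch: once prompt_len + output_len exceeds a budget, force the
    next token to be </think> by masking every other logit."""

    def __init__(self, think_end_token_id: int, max_total_tokens: int = 8192):
        self.think_end_token_id = think_end_token_id
        self.max_total_tokens = max_total_tokens
        # index -> (prompt_len, shared output_tok_ids list)
        self.reqs: Dict[int, Tuple[int, List[int]]] = {}

    def update_states(self, batch_update: Optional["BatchUpdate"] = None) -> None:
        if batch_update is None:
            return
        for index in batch_update.removed:
            self.reqs.pop(index, None)
        # Assumes `added` is extended to also carry the prompt token ids.
        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
            self.reqs[index] = (len(prompt_tok_ids), output_tok_ids)
        for frm, to in batch_update.moved:
            if (entry := self.reqs.pop(frm, None)) is not None:
                self.reqs[to] = entry

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for index, (prompt_len, output_tok_ids) in self.reqs.items():
            if (prompt_len + len(output_tok_ids) > self.max_total_tokens
                    and self.think_end_token_id not in output_tok_ids):
                # Over budget and </think> not yet emitted: force it.
                forced = logits[index, self.think_end_token_id].clone()
                logits[index].fill_(float("-inf"))
                logits[index, self.think_end_token_id] = forced
        return logits
```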
I've just rebased this PR, but haven't yet addressed some changes made since it was originally created:
- The TPU impl has separate sampling metadata handling; we need to see how this will work with that.
- There is a check for whether spec decode is supported based on the state of the input batch. Now that the input batch will contain LPs instead of the "raw" sampling param data, this check will need to be adjusted.
- An out-of-vocab token check was added to the logit bias impl on the main branch. The re-implemented logit bias logic in this PR doesn't include that, but we should move that check to be part of request validation anyhow (as I've now commented on that PR).
jfyi I am picking up this work, see here #16728
- was added
@njhill see my comment here
https://github.com/vllm-project/vllm/pull/16529#discussion_r2075990154
I think the validation inside of the logit bias logits processor is redundant
I assume no in-place update operations are going to be allowed, in order not to enforce any priority on the order in which the logits processors in `logits_procs` are applied, right?
@NickLucche sorry for the very delayed response, I missed your question originally. I think we'd preferably do the logits updates in-place, with some standard order for the built-in ones and the order of custom ones based on how they are configured.
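i.e., application could end up looking something like this (a sketch; how custom processors get configured and ordered is still open):

```python
def apply_all(logits: torch.Tensor,
              builtin_procs: List[LogitsProcessor],
              custom_procs: List[LogitsProcessor]) -> torch.Tensor:
    # Built-in processors run first in a fixed order, then custom ones in
    # the order they were configured; each may modify the logits in place.
    for proc in (*builtin_procs, *custom_procs):
        logits = proc.apply(logits)
    return logits
```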
@njhill Can we close this PR?