[RFC][V1] `LogitsProcessor` interface

Open njhill opened this issue 10 months ago • 2 comments

Proposed abstraction for how to handle sampling parameters in relation to the persistent batch. This interface could then be used as an extension point for custom logits processors.

Key goals/ideas:

  • Logits processor implementations are configured globally; per-request logits processors won't be supported.
  • They apply at the batch level rather than per-request, to allow for and encourage vectorized application.
  • Each logits processor encapsulates its own state and is responsible for updating it as needed based on notification of persistent batch updates and new output tokens each step. This minimizes the number of times tensors need to be reconstructed and updated on the GPU.

To demonstrate the idea I've implemented LPs for min_tokens, logit_bias and min_p, but if we decide to go this route it should be straightforward to refactor the others similarly.

Note this is just to discuss the general approach - it could still be simplified/refined further.

from __future__ import annotations

import dataclasses
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch

from vllm import SamplingParams


class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def update_states(
        self,
        batch_update: Optional[BatchUpdate] = None,
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: non-None iff there have been
                changes to the batch makeup.
        """
        raise NotImplementedError


@dataclasses.dataclass
class BatchUpdate:
    # Batch indices of any removed requests.
    removed: List[int]
    # (from, to) batch indices of any requests
    # moved within the batch.
    moved: List[Tuple[int, int]]
    # (index, params, output_tok_ids) for new
    # requests added to the batch.
    # TODO: may need to include one or two other things here,
    # like prompt token ids.
    added: List[Tuple[int, SamplingParams, List[int]]]
    # The current number of requests in the batch.
    batch_size: int
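
For a sense of how an implementation would use this, here is a minimal sketch of a batch-level min_p processor under the interface above (the class name and internals are illustrative only, not the actual implementation in this PR):

class MinPLogitsProcessor(LogitsProcessor):
    """Sketch: applies per-request min_p filtering across the whole batch."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        # One min_p value per batch slot; the GPU copy is refreshed only
        # when the batch makeup changes.
        self.min_p_cpu = torch.zeros(max_num_reqs, dtype=torch.float32)
        self.min_p_gpu = torch.zeros(max_num_reqs, dtype=torch.float32,
                                     device=device)

    def update_states(self, batch_update: Optional[BatchUpdate] = None) -> None:
        if batch_update is None:
            return  # No change to batch makeup; nothing to do for min_p.
        for index, params, _output_tok_ids in batch_update.added:
            self.min_p_cpu[index] = params.min_p
        for from_index, to_index in batch_update.moved:
            self.min_p_cpu[to_index] = self.min_p_cpu[from_index]
        # Copy the small per-request vector to the GPU once per batch change.
        size = batch_update.batch_size
        self.min_p_gpu[:size].copy_(self.min_p_cpu[:size], non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        num_reqs = logits.shape[0]
        min_p = self.min_p_gpu[:num_reqs].unsqueeze(1)
        probs = logits.softmax(dim=-1)
        top_probs = probs.max(dim=-1, keepdim=True).values
        # Mask tokens whose probability is below min_p * top probability
        # (requests with min_p == 0 are unaffected).
        logits[probs < min_p * top_probs] = float("-inf")
        return logits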

@WoosukKwon @AlpinDale @houseroad

njhill avatar Feb 16 '25 17:02 njhill

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Feb 16 '25 17:02 github-actions[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Feb 17 '25 01:02 mergify[bot]

@njhill shouldn't BatchUpdate also include the tokens that were sampled for each current sequence? it's not needed for the min_p/min_tokens/logit_bias ones you implemented, but would be needed for anything more complicated?

mmoskal avatar Feb 19 '25 22:02 mmoskal

> @njhill shouldn't BatchUpdate also include the tokens that were sampled for each current sequence? it's not needed for the min_p/min_tokens/logit_bias ones you implemented, but would be needed for anything more complicated?

So this is not necessarily the final state, but currently it's handled via the list that's passed in with the added requests. That list is updated in place with new tokens, so the impl can hang on to it if it needs to know them, and check it each time it's called. The idea is that BatchUpdate will only be present if requests have been added to or removed from the batch; otherwise, this list of output ids can be checked for all of the current requests that the LP cares about. Of course this will need to be clearly documented if it remains the way that it's done :)

There will definitely need to be changes to this regardless, like adding the prompt tokens, and we could possibly include a pointer to an on-device tensor of just the newly generated tokens (which will be there anyhow), since it might be faster for the LP to use that directly rather than multiple LPs each copying these tokens back to the GPU.

You can see an example of this in the min_tokens impl in this PR, which uses the length of that list.
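
Concretely, an LP can retain the output_tok_ids list it receives in BatchUpdate.added and just look at it on each call. A minimal min_tokens-style sketch (the class and state layout are hypothetical, and all_stop_token_ids is assumed to be available on SamplingParams):

class MinTokensLogitsProcessor(LogitsProcessor):
    """Sketch: masks stop tokens until each request reaches min_tokens."""

    def __init__(self):
        # batch index -> (output token ids list, min_tokens, stop token ids)
        self.state: dict[int, tuple[List[int], int, set]] = {}

    def update_states(self, batch_update: Optional[BatchUpdate] = None) -> None:
        if batch_update is None:
            return
        for index in batch_update.removed:
            self.state.pop(index, None)
        for index, params, output_tok_ids in batch_update.added:
            if params.min_tokens > 0:
                # Keep a reference to the list itself; the engine appends
                # newly sampled token ids to it in place each step.
                self.state[index] = (output_tok_ids, params.min_tokens,
                                     set(params.all_stop_token_ids))
        for from_index, to_index in batch_update.moved:
            if from_index in self.state:
                self.state[to_index] = self.state.pop(from_index)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for index, (out_ids, min_toks, stop_ids) in list(self.state.items()):
            if len(out_ids) < min_toks:
                # Too few output tokens so far: disallow stopping.
                logits[index, list(stop_ids)] = float("-inf")
            else:
                del self.state[index]
        return logits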

njhill avatar Feb 19 '25 22:02 njhill

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 21 '25 11:03 mergify[bot]

With the V1 interface, I'm curious whether the following code can still be used to customize logits processing for a single batch:

outputs = llm.generate(
    prompts,
    SamplingParams(
        temperature=temperature,
        max_tokens=1,
        allowed_token_ids=accepted_ids
    ),
)

And if not, is there any alternative solution?

Thank you for your time.

jiwangyihao avatar Mar 25 '25 02:03 jiwangyihao

You can set the guided decoding backend to "guidance" and use the following grammar, assuming your tokens are 74, 837, 1010, 1011, 1012, 1013:

start: ( <[74]> | <[837]> | <[1010-1013]> )*

for details see https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens

Make sure you have llguidance at least 0.7.10 (should be ready in ~20 minutes); 0.7.9, as used in vLLM, would reject this grammar in the front-end.

Of course you can also use strings in your grammar, not tokens, but then we will allow any tokenization of the given string, which may or may not be what you want.
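
Putting that together with the offline LLM API would look roughly like this (a sketch; it assumes GuidedDecodingParams and the guided_decoding_backend engine argument are available in your vLLM version, that llguidance >= 0.7.10 is installed, and the model name is just a placeholder):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Lark-style llguidance grammar that only allows the listed token ids.
grammar = "start: ( <[74]> | <[837]> | <[1010-1013]> )*"

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          guided_decoding_backend="guidance")

outputs = llm.generate(
    ["Pick one of the allowed tokens:"],
    SamplingParams(
        temperature=0.0,
        max_tokens=1,
        guided_decoding=GuidedDecodingParams(grammar=grammar),
    ),
)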

mmoskal avatar Mar 25 '25 22:03 mmoskal

If I want to force an R1-like model to stop thinking (e.g., when prompt + outputs > 8192, force it to generate </think>), how can I get the prompt length for a single request? Thank you for your time.

wlhgtc avatar Mar 26 '25 10:03 wlhgtc

I've just rebased this PR, but haven't yet addressed some changes made since it was originally created:

  • The TPU impl has separate sampling metadata handling; we need to see how this will work with that.
  • There is a check for whether spec decode is supported based on the state of the input batch. Now that the input batch will contain LPs instead of the "raw" sampling param data, this check will need to be adjusted.
  • An out-of-vocab token check was added to the logit_bias impl on the main branch. The re-implemented logit_bias logic in this PR doesn't include it, but we should move that check into request validation anyhow (as I've now commented on that PR).

njhill avatar Apr 15 '25 16:04 njhill

jfyi I am picking up this work, see here #16728

afeldman-nm avatar Apr 16 '25 15:04 afeldman-nm

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 18 '25 08:04 mergify[bot]

@njhill see my comment here

https://github.com/vllm-project/vllm/pull/16529#discussion_r2075990154

I think the validation inside the logit bias logits processor is redundant.

afeldman-nm avatar May 06 '25 17:05 afeldman-nm

> I assume no in-place update operations are going to be allowed, in order to not enforce any priority on the order in which the logits processors in logits_procs are applied, right?

@NickLucche sorry for the very delayed response; I missed your question originally. I think we'd preferably do the logits updates in place, with some standard order for the built-in ones, and the order of custom ones based on how they are configured.
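
In that scheme, the sampler-side application could be as simple as something like the following (a sketch; the logits_procs container and its ordering are assumptions, building on the interface defined above):

def apply_logits_processors(
    logits: torch.Tensor,
    logits_procs: List[LogitsProcessor],
    batch_update: Optional[BatchUpdate],
) -> torch.Tensor:
    # Built-in processors first, then custom ones in their configured order;
    # each may modify `logits` in place and returns the (same) tensor.
    for proc in logits_procs:
        proc.update_states(batch_update)
        logits = proc.apply(logits)
    return logits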

njhill avatar May 07 '25 18:05 njhill

@njhill Can we close this PR?

ywang96 avatar Oct 07 '25 08:10 ywang96