[RFC]: Improve guided decoding (logit_processor) APIs and performance.
Motivation.
Currently, the guided decoding & logits processor API is incomplete and has several issues. This RFC is intended to bring up the problems and propose solutions. Some of the issues may have already been addressed, and there are already PRs out.
There are 3 major issues.
- It is not supported from SamplingParams.
- It is not possible to support batch/async logit processing.
- Upon failures, the engine will die.
Proposed Change.
API
Guided decoding parameters are not supported with SamplingParams. This is addressed in https://github.com/vllm-project/vllm/pull/4130
Performance
Currently, logits processors are applied row by row in a blocking manner (https://github.com/vllm-project/vllm/blob/246598a6b1e22616630b7f1bf11bd9bcb31dc860/vllm/model_executor/layers/logits_processor.py#L112). Instead, we can use parallel processing (e.g., Ray or a thread pool) to improve logit processing performance. We are using this mechanism internally at Anyscale. We'd like to support this feature in OSS, and would like to improve the logit processor API to support 1. async and 2. batching.
This requires the logit processor to be
- stateful (to use a tool like Ray or a thread pool). I think this PR https://github.com/vllm-project/vllm/pull/5329 is likely sufficient.
- async. We'd like to propose a "prepare" API which separates preparing the logit masks from compute_logits.
```python
class LogitPostProcessor:

    def initialize(self, logit_processor_config: LogitProcessorConfig):
        """Initialize the post processor. The post processor may have state
        such as a thread pool or Ray actors; it should be created here.
        """
        ...

    def prepare(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Asynchronously prepare logit masks."""
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the prepared masks to the given logits."""
        ...


# For each model, we will have
def compute_logits(hidden_states, sampling_metadata):
    ...

def prepare_logits(seq_group_metadata_list):
    ...
```
prepare and apply assume 1:1 calls, i.e., once prepare is called, apply has to be called before another prepare is called. I think it is a safe assumption. Alternatively, we can make prepare return a class, but that would make the interface surface larger, so I don't prefer that solution (but I am open to hearing feedback!).
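For comparison, a rough sketch of what that rejected alternative could look like (PreparedLogitMasks is a hypothetical name, not part of this proposal):

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class PreparedLogitMasks:
    """Hypothetical handle returned by prepare() instead of keeping
    state on the processor itself."""
    # Additive masks: 0 for allowed tokens, -inf for disallowed tokens.
    masks: torch.Tensor

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits + self.masks


class StatelessLogitPostProcessor:
    def prepare(
        self,
        seq_group_metadata_list: List["SequenceGroupMetadata"],
    ) -> PreparedLogitMasks:
        """Build and return the masks; the caller holds on to the handle."""
        ...
```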
This is an example usage of the proposed (stateful) API:
```python
# each model will have a prepare_logits API
self.model.prepare_logits(seq_group_metadata_list)

hidden_states = model_executable(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=kv_caches,
    attn_metadata=attn_metadata,
    **multi_modal_kwargs,
)

# Compute the logits. Logit processors are applied here.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
```
We are also considering upstreaming a Ray-based batch processing implementation with lm-format-enforcer.
Failure Handling
When using a stateful logit processor, it is possible for requests to fail. For example, if we use Ray, Ray actors can die. Or there could be an issue with the user's schema that cannot be caught ahead of time.
When this happens, we should fail the seq_group immediately. We will introduce a new status "FINISHED_INTERNAL_ERROR = enum.auto()" to https://github.com/vllm-project/vllm/blob/246598a6b1e22616630b7f1bf11bd9bcb31dc860/vllm/sequence.py#L42. If any logit processor fails, we will mark the relevant seq_group as failed, and the request will be aborted.
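Concretely, the status addition would look roughly like this (a sketch; the existing members of SequenceStatus are abbreviated here):

```python
import enum


class SequenceStatus(enum.Enum):
    """Status of a sequence (abbreviated; the real enum has more members)."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    # New: set when a logit processor fails (e.g. a dead Ray actor or an
    # invalid user schema), so the request is aborted instead of killing
    # the engine.
    FINISHED_INTERNAL_ERROR = enum.auto()
```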
Feedback Period.
No response
CC List.
cc @simon-mo @Yard1
Any Other Things.
No response
cc @njhill @br3no @mmoskal
I have a few questions:
> It is not supported from SamplingParams
Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in #4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.
> This requires the logit processor to be
> - stateful (to use a tool like Ray or a thread pool). ...
Do you maybe mean stateless? If not, what do you mean exactly?
Regarding the topic of statefulness: we probably don't want to limit ourselves to stateless logits processors. If we manage to make the API so that it is easy to implement stateful logits processors, we would already make things much better. E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.
Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?
Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?
I might have misunderstood the proposal though. So, I'd be really happy if you could elaborate on it.
All in all, I would be very interested in improvements in this area, so I'm glad you're working on it!
> Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in https://github.com/vllm-project/vllm/pull/4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.
It's like moving the functionality to the core API. Right now, it is implemented like an add-on (only working with the OpenAI server), and it doesn't work with tools like https://github.com/anyscale/ray-llm (because we directly use the core API). It requires code that breaks the abstraction barrier (i.e., creating a logit processor), and given that guided decoding is a core function, I feel like having the API in SamplingParams makes sense.
> Do you maybe mean stateless? If not, what do you mean exactly?
To improve the time to prepare masks for json mode, we want to use parallel processing tools such as a thread pool or Ray. It requires the logit processor to be "stateful" because we don't want to recreate actors or thread pools every time a logit processor is requested (they should be created in __init__).
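As a rough illustration (not our internal implementation), a stateful processor could own a thread pool created once in __init__ and reuse it across prepare() calls; _build_mask below is a placeholder:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch


class ThreadPoolLogitPostProcessor:
    """Illustrative only: builds per-row logit masks on a pool owned by the
    processor, so the pool is not recreated on every call."""

    def __init__(self, vocab_size: int, num_workers: int = 8):
        self._vocab_size = vocab_size
        # Created once and reused for every batch.
        self._pool = ThreadPoolExecutor(max_workers=num_workers)
        self._futures = []

    def prepare(self, seq_group_metadata_list: List["SequenceGroupMetadata"]) -> None:
        # Kick off mask construction in the background while the model's
        # forward pass is running.
        self._futures = [
            self._pool.submit(self._build_mask, metadata)
            for metadata in seq_group_metadata_list
        ]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Block only when the masks are actually needed.
        masks = torch.stack([future.result() for future in self._futures])
        return logits + masks.to(logits.device)

    def _build_mask(self, seq_group_metadata) -> torch.Tensor:
        # Placeholder: a real implementation would consult the guided-decoding
        # FSM for this sequence and set disallowed token ids to -inf.
        return torch.zeros(self._vocab_size)
```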
> E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.
+1. I think it'd be an implementation of part 2.
> Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?
It will replace the _apply_logit_processor private API inside logit_processor.py. Right now, we apply the logit mask row by row. Instead, we 1. find the relevant logit processor that was created, and 2. call logit_processor.prepare(seq_group_metadata_list) followed by logit_processor.apply(logits).
> Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?
You are right that prepare and apply are stateful. Alternatively, we could make it work this way:
```python
masks = self.model.prepare_logits(seq_group_metadata_list)

hidden_states = model_executable(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=kv_caches,
    attn_metadata=attn_metadata,
    **multi_modal_kwargs,
)

# Compute the logits. Logit processors are applied here.
logits = self.model.compute_logits(hidden_states, sampling_metadata, masks)
```
But I found it easier to just make it fully stateful.
Hope this clarifies the proposal a little bit!
We should make this work with the following RFCs:
- @NadavShmayo https://github.com/vllm-project/vllm/pull/4769
- @mmoskal https://github.com/vllm-project/vllm/pull/4775
- @mmoskal https://github.com/vllm-project/vllm/pull/2888
- @maxdebayser @njhill https://github.com/vllm-project/vllm/pull/5329
- @lynkz-matt-psaltis https://github.com/vllm-project/vllm/pull/5006
My initial thoughts:
- https://github.com/vllm-project/vllm/pull/4769 -> The plugin needs to change its API to accept a stateful logit processor (@njhill already commented on the PR).
- https://github.com/vllm-project/vllm/pull/4775 -> Right now, it feels orthogonal (there are overlaps, but the APIs mentioned in that PR are supposed to be used "with" a logit processor, and this RFC is logit processor specific. If we use a thread pool or Ray, we cannot prepare logits outside the model_runner, so the API is limited).
- https://github.com/vllm-project/vllm/pull/5329 seems like an implementation of the second part.
- https://github.com/vllm-project/vllm/pull/5006 seems orthogonal (existing logit processors will just work with the current proposal; prepare is simply a no-op for them).
Some ideas:
- maybe `initialize()` can be async? The reason is that we don't want to start scheduling sequences while the processor is still initializing (in case it takes a few seconds)
- add some sort of `free()` API so resources can be freed
With an additional post-sampling callback, this would subsume my SequenceController #4775:

```python
def sampled(self, seq: 'Sequence', token_id: int,
            logprobs: Dict[int, 'Logprob']) -> Tuple[int, List[int], bool]:
    """
    Informs the controller a given token has been sampled.
    Returns the number of tokens to backtrack, the tokens to append,
    and whether to stop.
    """
    if token_id == seq.eos_token_id:
        return 0, [], True
    return 0, [token_id], False
```
> With an additional post-sampling callback, this would subsume my SequenceController https://github.com/vllm-project/vllm/pull/4775:
I see. I found that API limited for our particular use case because, as you know, it is applied after sampling is done (whereas we want to apply the logit processor on the final logits). It'd be great if we can subsume it.
> add some sort of `free()` API so resources can be freed
I am open to it, but right now there are no specific use cases.
> maybe `initialize()` can be async? The reason is that we don't want to start scheduling sequences while the processor is still initializing (in case it takes a few seconds)
How is this guaranteed now?
@rkooo567 thanks, let me see if I understand it:
The idea is that the logits processors will be asked to prepare their masks asynchronously and in the meantime the model is going to be run. Once both are ready, the logits are computed by having the model call apply.
This means that the whole process needs to guarantee that there is one logits processor instance per request per sequence. Correct?
The implementation will need to be very careful to avoid contention issues.
Regarding the combination of this with the other PRs: I'm still struggling a bit to understand what general design we need. Let me explain:
The logits processors are now applied in the models; so the general signature of the operation is `compute_logits(hidden_states: Tensor, ...) -> Tensor`.
We want to support ff-tokens or backtracking (e.g. #4775). These things happen a few layers above the model and don't fit this API above.
So we're talking about different things in different abstraction layers at the same time.
Am I the only one? Is the design clear to you folks? If so, I would appreciate it a lot if someone could describe which types of objects would play which roles, and where.
@br3no One thing that took me a while to see is that there is only one LogitPostProcessor per LLMEngine - it handles logits for all sequences in the current batch.
There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of `LogitPostProcessor`s, so I think there's no need to have a list of post processors in vLLM.
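For example, a bundling processor under the proposed interface could be as simple as this sketch (assuming the prepare/apply API above):

```python
from typing import List

import torch


class CompositeLogitPostProcessor:
    """Sketch: fans prepare()/apply() out to an arbitrary list of processors."""

    def __init__(self, processors: List["LogitPostProcessor"]):
        self._processors = processors

    def prepare(self, seq_group_metadata_list) -> None:
        # Every child can prepare its masks concurrently with the forward pass.
        for processor in self._processors:
            processor.prepare(seq_group_metadata_list)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Masks compose by applying each processor in turn.
        for processor in self._processors:
            logits = processor.apply(logits)
        return logits
```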
I'm the one asking for ff_tokens and backtracking, I think @rkooo567 is not doing this now.
@rkooo567 @simon-mo @mmoskal some additional thoughts after we talked offline yesterday:
It's a concern that the current support is kind of broken, it doesn't work for input batches or beam search due to the stateful/concurrency thing. So I wonder if we could prioritize some simpler immediate fixes for that along with the egregious performance overhead with json mode due to having to construct a new CFGuide instance every time. i.e. before the more significant rework to introduce batched application and the prepare step... WDYT?
A couple of other thoughts about the proposed interface:
- Why would we need an `initialize` method, couldn't a regular constructor be used for this?
- I'm not sure that it's a good idea to expose `List[SequenceGroupMetadata]` in this API ... I had assumed `SequenceGroupMetadata` is an internal datastructure that we want the freedom to change without breaking 3rd party LogitsProcessor impls. Probably should have some simpler dataclass or abstract class designed specifically for the API.
@mmoskal thanks for your answer! I also would like to support ff-tokens since I think this would contribute to alleviating the performance issues.
@njhill I’m not familiar with lm-format-enforcer, but for the Outlines processors now only the CFG one is problematic. The others are now stateless. Should we concentrate on a “fix” for the output_format: json issue? This would involve an object pool for the CFGGuide for that particular use case. Or am I missing other aspects here?
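One possible shape for such a pool, as a sketch (the factory callable and keying by grammar string are assumptions, not an existing vLLM or Outlines API):

```python
import queue
from typing import Callable, Dict


class GuidePool:
    """Sketch of a per-grammar object pool so CFG guides are reused across
    requests instead of being rebuilt every time."""

    def __init__(self, factory: Callable[[str], object]):
        # e.g. factory = lambda grammar: CFGGuide(grammar, tokenizer)
        self._factory = factory
        self._pools: Dict[str, queue.SimpleQueue] = {}

    def acquire(self, grammar: str):
        pool = self._pools.setdefault(grammar, queue.SimpleQueue())
        try:
            return pool.get_nowait()
        except queue.Empty:
            # No idle guide available for this grammar; build a new one.
            return self._factory(grammar)

    def release(self, grammar: str, guide) -> None:
        # The guide must be reset (or otherwise carry no per-sequence state)
        # before it is handed to another sequence.
        self._pools.setdefault(grammar, queue.SimpleQueue()).put(guide)
```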
> There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of `LogitPostProcessor`s, so I think there's no need to have a list of post processors in vLLM.
I also agree with that. I have the impression that the current interface is a little over-designed with some vague implementation in mind. For ff-tokens and backtracking, I would like to see the implementation; otherwise it is very difficult to design the interface (that's why we punted). I think the interface I propose here is not going to prevent us from getting there (the logit processor API also doesn't feel like a very stable API yet, so we have time to iterate).
> It's a concern that the current support is kind of broken, it doesn't work for input batches or beam search due to the stateful/concurrency thing. So I wonder if we could prioritize some simpler immediate fixes for that along with the egregious performance overhead with json mode due to having to construct a new CFGuide instance every time. i.e. before the more significant rework to introduce batched application and the prepare step... WDYT?
Does it mean supporting a stateful logit processor first (i.e., merging the open PR)? I am okay with this.
> Why would we need an `initialize` method, couldn't a regular constructor be used for this?
I think a regular constructor could work. The main reason was that we need to pass the decoding config to the logit processor, and since it is inside the model, the required change was big. I think a constructor makes more sense, actually.
> I'm not sure that it's a good idea to expose `List[SequenceGroupMetadata]` in this API ... I had assumed `SequenceGroupMetadata` is an internal datastructure that we want the freedom to change without breaking 3rd party LogitsProcessor impls. Probably should have some simpler dataclass or abstract class designed specifically for the API.
Yeah, it is a good point. For our internal impl, we just need seq_data, seq_ids, request_id, and sampling params.
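For instance, the API-facing structure could be as small as this sketch (the name and field types are assumptions based on the fields listed above):

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class LogitProcessorBatchEntry:
    """Sketch of a minimal, stable view handed to logits processors instead
    of the internal SequenceGroupMetadata."""
    request_id: str
    seq_ids: List[int]
    seq_data: Dict[int, "SequenceData"]  # seq_id -> prompt / output token ids
    sampling_params: "SamplingParams"
```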
I did a first pass on this in #6273
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
Hi all, for those who are following this thread, I started benchmarking current performance for guided decoding in vLLM here #10046