[Dynamic Spec Decoding] Auto-disable by the running queue size
The first PR for #4565.
This PR adds an auto-disable mechanism to speculative decoding. Specifically, we allow users to set a threshold, in terms of the number of requests in the current running queue, to disable speculative decoding for new incoming requests.
Example usage: --speculative-disable-queue-size 4. This disables speculative decoding for the current batch of requests if the running queue has more than 4 requests.
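A minimal Python sketch of the gating rule described above. Request, should_disable_spec_decode, assign_spec_tokens and num_lookahead are hypothetical names, not vLLM's scheduler API. The sketch only encodes the "more than N requests in the running queue" rule from the description and makes one decision per scheduling step.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    # Hypothetical stand-in for an entry in the scheduler's running queue.
    request_id: str
    num_spec_tokens: int = 0


def should_disable_spec_decode(running_queue_size: int,
                               disable_queue_size: int) -> bool:
    # Rule from the PR description: disable when the running queue holds
    # *more than* disable_queue_size requests.
    return running_queue_size > disable_queue_size


def assign_spec_tokens(running: List[Request], disable_queue_size: int,
                       num_lookahead: int) -> List[Request]:
    # One decision per scheduling step, applied to the whole batch.
    disable = should_disable_spec_decode(len(running), disable_queue_size)
    k = 0 if disable else num_lookahead
    for req in running:
        req.num_spec_tokens = k
    return running


# Usage: threshold 4, mirroring --speculative-disable-queue-size 4.
small = assign_spec_tokens([Request(str(i)) for i in range(3)], 4, 5)
large = assign_spec_tokens([Request(str(i)) for i in range(6)], 4, 5)
print([r.num_spec_tokens for r in small], [r.num_spec_tokens for r in large])
```

With 3 running requests every request keeps 5 speculative tokens; with 6 running requests speculation is turned off for the whole batch.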
cc @cadedaniel @LiuXiaoxuanPKU @leiwen83
QQ: can you also update the benchmark result?
The current speculative decoding performance isn't good enough due to the lack of bonus token, so the benchmark result may be less meaningful. We should perform another round of benchmarking later once all missing pieces are in place.
About the lack of a bonus token: do you mean that the resampling step is completely missing, or just that the bonus token is missing in the all-accepted case? If it's the latter, that case is a minority and shouldn't affect the benchmark much.
It's discussed in #4212, and it's the latter case you mentioned. However, it does affect performance a lot in certain cases. First, without a bonus token, there's no guarantee that each step generates at least one token. Second, inter-token latency (ITL) can be calculated as decoding-step-latency / accepted-tokens-in-this-step. Assuming the number of speculative tokens is 5 and the acceptance rate without the bonus token is 60%, we accept 3 tokens per step, so ITL=L/3, where L is the decoding-step latency. With a bonus token, it becomes ITL=L/4, which is a 33% speedup.
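The ITL arithmetic above, written out as a small Python calculation. It follows the comment's simplification that a 60% acceptance rate on 5 speculative tokens means exactly 3 accepted tokens per step, and it normalizes the step latency L to 1 (an arbitrary unit chosen for illustration).

```python
def inter_token_latency(step_latency: float, tokens_per_step: int) -> float:
    # ITL = decoding-step latency / tokens emitted in that step.
    return step_latency / tokens_per_step


step_latency = 1.0       # L, normalized to 1 (arbitrary unit)
num_spec_tokens = 5
acceptance_rate = 0.6
accepted = round(num_spec_tokens * acceptance_rate)  # 3 tokens per step

itl_base = inter_token_latency(step_latency, accepted)        # L/3
itl_extra = inter_token_latency(step_latency, accepted + 1)   # L/4
speedup = itl_base / itl_extra                                # 4/3
print(f"ITL {itl_base:.3f} -> {itl_extra:.3f}, throughput x{speedup:.2f}")
```

The 33% figure is the throughput ratio (4/3); measured as latency, ITL itself drops by 25%.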
Updates to the rejection sampler (cc @cadedaniel for reviewing these API changes to the rejection sampler):
- Enable bonus token for PLD.
- Disable strict_mode using the environment variable VLLM_DISABLE_REJECT_SAMPLING_STRICT_MODE.
With these updates, the performance should be acceptable, so I'll try to benchmark on A100 when I get a chance, but this should not block this PR.
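PLD (prompt lookup decoding, i.e. n-gram speculation) proposes tokens by matching the tail of the sequence against earlier text instead of running a draft model, which is why the bonus token costs nothing there. A minimal Python sketch of such a proposer; ngram_propose and its parameters are hypothetical and this is not vLLM's n-gram worker.

```python
from typing import List, Optional


def ngram_propose(tokens: List[int], ngram_size: int,
                  num_lookahead: int) -> Optional[List[int]]:
    """Hypothetical prompt-lookup (n-gram) proposer.

    Uses the last ``ngram_size`` tokens as a query, finds the most recent
    earlier occurrence, and proposes up to ``num_lookahead`` tokens that
    followed it. No draft model runs, so there is no draft KV cache.
    """
    if len(tokens) <= ngram_size:
        return None
    query = tokens[-ngram_size:]
    # Scan right-to-left so the most recent match wins; the start range
    # stops before the query itself.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == query:
            end = start + ngram_size
            return tokens[end:end + num_lookahead]
    return None


# Usage: the tail [1, 2] also appears at the start, so [3, 4, 1] is proposed.
print(ngram_propose([1, 2, 3, 4, 1, 2], ngram_size=2, num_lookahead=3))
```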
Disable strict_mode using environment variable VLLM_DISABLE_REJECT_SAMPLING_STRICT_MODE.
This is not necessary; we can simply set strict_mode to False (we only had it in for development correctness, now we can disable it outside of unit tests).
Thanks for the clarification! I'm very interested in the implementation details then. I also implemented a speculative decoding system and ran into the corner case of how to deal with the bonus token in the all-accepted case. I moved the bonus token sampling to the next round of the speculative iteration because, as mentioned in #4212, "KV is not generated for the draft model", so some extra prefilling might be needed in the next round. [It just occurred to me that in the inter-token latency (ITL) calculation above, this extra draft sampling (the extra prefilling) should also be considered, leading to an effective speculation length of L+1, in this case 5 + 1 = 6.] I look forward to seeing how much this bonus token improves performance in this PR.
tests failing unfortunately
The failing tests seem flaky: the same 2 spec decode tests fail on the main branch as well. We should retry all failed tests.
It seems that since https://github.com/vllm-project/vllm/pull/4551 was merged, that test fails on the main branch. Investigating...
There's an inefficient allocation during spec decode which can cause OOM when paired with a large batch size. I lowered the batch size in that test, it passes locally for me.
Root-cause fix is here: https://github.com/vllm-project/vllm/pull/4672/files. The bug was introduced because the ngram tests weren't actually running (fixed in https://github.com/vllm-project/vllm/pull/4551), but we didn't merge main beforehand and thus missed the logprobs x ngram combination 🫠
Thanks for fixing that! Could you help retry the failed tests again?
Retried
retry again?
@richardliaw seems it’s waiting for compute. I found this out by opening the buildkite link.
retry again?
It is retrying automatically but waiting for the resource...
Hi @comaniac, I reconsidered the bonus token in spec decoding and found that it is almost exactly equivalent to increasing the speculation length by one. I'd like to share this with you for reference. Suppose the speculation length is 5 (L=5), as you assumed above. For the bonus token to be sampled, the 5th drafted token needs to be fed into the large target model (along with the 0th-4th tokens), the 5th token needs to be fed into the draft model to compute its kv_cache, and the 6th token (the bonus token) is used as the input in the next round.
In contrast, if the bonus token is not sampled, only the tokens up to the 4th need to be fed into the draft model and the target model. The 5th token is used to index the logits of the 4th token's target model output, and is then accepted or resampled. The accepted 5th token is then used as the input in the next round.
So in the second scenario L=5 (the 0th-4th tokens are fed into the target model); in the first, L=6 (the 0th-5th tokens are fed into the target model), regardless of whether all of the request's drafted tokens are accepted.
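The token accounting for the two scenarios above, as a small Python helper. Position indices follow the comment (position 0 is the input token carried over from the previous round, positions 1..L are drafted tokens); target_positions is a hypothetical name used only for illustration.

```python
def target_positions(spec_len: int, sample_bonus: bool) -> range:
    """Positions fed to the target model in one round (hypothetical).

    Position 0 is the carried-over input token; 1..spec_len are drafted.
    """
    # With a bonus token, the last drafted token is also fed in so that its
    # successor (the bonus token) can be sampled. Without it, the last
    # drafted token only indexes the logits produced at the previous position.
    last = spec_len if sample_bonus else spec_len - 1
    return range(0, last + 1)


print(list(target_positions(5, sample_bonus=True)))   # 0th-5th tokens
print(list(target_positions(5, sample_bonus=False)))  # 0th-4th tokens
```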
But you did mention that
First, without a bonus token, there's no guarantee that each step will generate at least one token
The speculative decoding process I described above doesn't have this issue, since each round starts with an input_token, either from the last round or from the prefill, as mentioned above. This input_token is fed into both the target and the draft model in the current round, and at least one token is guaranteed to be generated thanks to resampling from the residual distribution. So your implementation may be a little different from the process described above.
As you pointed out, my previous description about generating at least one token is incorrect. The token that guarantees at least one token is generated in each step is called the "recover" token, which is different from the bonus token. The bonus token is only accepted when all speculative tokens are accepted.
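To make the recovered-versus-bonus distinction concrete, here is a Python sketch of the standard speculative-sampling acceptance rule over explicit probability vectors. The function names are hypothetical and this is not vLLM's RejectionSampler: a drafted token is accepted with probability min(1, p/q); on the first rejection a recovered token is drawn from the normalized residual max(0, p - q) and the round ends, so at least one token is always emitted; only if all k drafted tokens are accepted is a bonus token drawn from the target's (k+1)-th distribution.

```python
import random
from typing import List, Sequence


def sample(probs: Sequence[float], rng: random.Random) -> int:
    # random.choices normalizes the weights, so the residual need not sum to 1.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def verify(draft_tokens: List[int],
           draft_probs: List[Sequence[float]],
           target_probs: List[Sequence[float]],
           rng: random.Random) -> List[int]:
    """draft_probs[i] and target_probs[i] are the draft (q) and target (p)
    distributions at position i; the target provides k + 1 of them."""
    k = len(draft_tokens)
    assert len(draft_probs) == k and len(target_probs) == k + 1
    out: List[int] = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i][tok], draft_probs[i][tok]
        if rng.random() < min(1.0, p / q):
            out.append(tok)  # drafted token accepted
            continue
        # Rejected: emit a recovered token from the residual max(0, p - q).
        residual = [max(0.0, pt - qt)
                    for pt, qt in zip(target_probs[i], draft_probs[i])]
        out.append(sample(residual, rng))
        return out
    # All k drafted tokens accepted: emit the bonus token.
    out.append(sample(target_probs[k], rng))
    return out
```

When the target assigns zero probability to the first drafted token, the round emits exactly one recovered token; when draft and target agree, all drafted tokens plus one bonus token are emitted.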
So in the second scenario L=5 (the 0th-4th tokens are fed into the target model); in the first, L=6 (the 0th-5th tokens are fed into the target model), regardless of whether all of the request's drafted tokens are accepted.
Yes, we always feed the L tokens proposed by the draft model to the target model, whether or not the proposed tokens are accepted. Let me try to summarize the steps in my own words:
1. [Propose] Run the draft model for L steps to generate L speculative tokens.
2. [Verify] Run the target model on the L speculative tokens at once to generate the kv-cache and logprobs. Because of that, the (L+1)th token is automatically generated by the target model.
3. [Sample] Decide how many speculative tokens should be accepted.
4. If all tokens are accepted, feed the bonus token generated by the target model to the draft model to update its kv-cache.
Thus, the cost of the bonus token is step 4, which updates the kv-cache of the draft model. This is the feature vLLM is currently missing (but we will add it soon). Fortunately, some speculative decoding methods, such as prompt lookup decoding (i.e., n-gram), don't need a draft model, so the bonus token is completely free.
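A toy Python walk-through of one round following the [Propose], [Verify], [Sample] steps and the final draft kv-cache update. Assumptions: "models" are hypothetical greedy next-token functions, acceptance is greedy (temperature 0) rather than the probabilistic rule, and the draft kv-cache update is only counted as one extra draft call, not simulated.

```python
from typing import Callable, List, Tuple

# Hypothetical greedy "model": maps a token prefix to the next token.
Model = Callable[[List[int]], int]


def spec_decode_round(tokens: List[int], draft: Model, target: Model,
                      num_spec: int) -> Tuple[List[int], int]:
    """Returns (tokens emitted this round, number of draft-model calls)."""
    draft_calls = 0
    # [Propose] run the draft model num_spec times.
    proposal: List[int] = []
    for _ in range(num_spec):
        proposal.append(draft(tokens + proposal))
        draft_calls += 1
    # [Verify] a real system scores all prefixes in one batched target pass;
    # here each prefix is scored separately. The last entry is the
    # (num_spec + 1)-th prediction, i.e. the would-be bonus token.
    target_preds = [target(tokens + proposal[:i]) for i in range(num_spec + 1)]
    # [Sample] greedy acceptance: keep drafted tokens while they match.
    emitted: List[int] = []
    for i, tok in enumerate(proposal):
        if tok != target_preds[i]:
            emitted.append(target_preds[i])  # recovered token, round ends
            return emitted, draft_calls
        emitted.append(tok)
    # All accepted: emit the bonus token and count one extra draft call for
    # the draft kv-cache update (the step-4 cost; not simulated here).
    emitted.append(target_preds[num_spec])
    draft_calls += 1
    return emitted, draft_calls


def inc(prefix: List[int]) -> int:
    return prefix[-1] + 1


print(spec_decode_round([1], draft=inc, target=inc, num_spec=3))
print(spec_decode_round([1], draft=lambda p: 0, target=inc, num_spec=3))
```

When the draft always agrees with the target, the round emits 4 tokens for 4 draft calls; when it disagrees at once, it emits 1 recovered token for 3 draft calls.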
I see. Thanks, this description is much clearer. So, in your description, at step 1 the 1st, 2nd, ..., L-th tokens are generated, which means the 0th, 1st, ..., (L-1)-th tokens are fed into the draft model. (The 0th token is the last generated token from the previous round.) At step 2, the 0th, 1st, ..., (L-1)-th tokens are fed into the target model. Here, L tokens in total are still processed by the target model, but not the same L tokens you mentioned. The key difference is then whether the L-th token is fed to the target model. If yes, then for a fully accepted candidate sequence, (L+1) draft model inference calls occur; if no, only L draft model calls occur. The target model input size is (L+1) tokens in both situations. I see. This is equivalent to adaptively increasing the speculation length by 1 for fully accepted sequences, hence L+1 draft calls, while sequences that are not fully accepted get only L draft calls, something like a dynamic speculation length: https://github.com/vllm-project/vllm/issues/4565. Yeah, it will indeed be beneficial.
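The "adaptive +1 speculation length" observation can be quantified with a simple model. Assumption (not from the thread): each drafted token is accepted independently with probability alpha. Then the expected number of accepted drafted tokens is the sum of alpha^i for i = 1..L, a fully accepted round happens with probability alpha^L, and the bonus token adds alpha^L expected tokens per round at the cost of alpha^L expected extra draft calls. expected_round_stats is a hypothetical helper.

```python
from typing import Dict


def expected_round_stats(alpha: float, spec_len: int,
                         bonus: bool) -> Dict[str, float]:
    """Expected tokens emitted and draft calls per round under i.i.d.
    per-token acceptance probability alpha (a simplifying assumption)."""
    p_all = alpha ** spec_len  # probability that every drafted token is accepted
    accepted = sum(alpha ** i for i in range(1, spec_len + 1))
    if bonus:
        # Every round ends with one extra token: recovered or bonus.
        tokens = accepted + 1.0
        # The extra draft call only happens in fully accepted rounds.
        draft_calls = spec_len + p_all
    else:
        # Fully accepted rounds emit no extra token.
        tokens = accepted + (1.0 - p_all)
        draft_calls = float(spec_len)
    return {"tokens": tokens, "draft_calls": draft_calls}


print(expected_round_stats(0.5, 2, bonus=True))
print(expected_round_stats(0.5, 2, bonus=False))
```

With alpha = 0.5 and L = 2, the bonus token raises expected tokens per round from 1.5 to 1.75 while expected draft calls go from 2 to 2.25.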