Fix verify tokens with the correct bonus token
The real bonus token should be the first unmatched token. For example, the draft_token_ids is [1, 2, 3, 5, 7], and the target_token_ids is [1, 2, 3, 4, 6, 8]. Then, the matched token should be [1, 2, 3], and the bonus token should be [4] because the target model will output [4] based on the input [1, 2, 3] in the next round generation. So we can select the bonus token [4] in this round without any precision regression.
It will increase the performance from 89 tokens/s to 110 tokens/s in typical_acceptance_sampler in A100 single card with (num_speculative_tokens=2, max_num_seqs=1, model="meta-llama/Llama-2-7b-chat-hf", speculative_model="Felladrin/Llama-68M-Chat-v1"). The outputs are exactly the same before and after my changes.
Do you mind review this PR? @cadedaniel cc @LiuXiaoxuanPKU
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add
readylabel to the PR - Enable auto-merge.
🚀
Hey, thanks for the interest! I want to align some definition here:
- In vllm,
bonus_token_idsis defined as
The "bonus" token ids that are accepted iff all speculative tokens in a sequence are accepted.
That being said, if the draft_token_ids is [1, 2, 3, 5, 7], and the target_token_ids is [1, 2, 3, 5, 7, 8]. Then the bonus_token_ids for this record is [8]. Notice here, all proposed tokens 1, 2, 3, 5, 7 are accepted.
2. For the example you gave, we say 4 is a recovered_token. To get the recovered token, instead of getting it from the target model directly. vllm samples the recovered token from a new distribution as shown here. (Minor thing, for greedy decoding, yeah the sampled token is the same as the target token, but it might be different for standard sampling.) We implemented it to strictly follow the paper.
Please let me know if there is any confusion here! Sorry for different terms here, we might simplify it in the future.
After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get. I also suggestion trying the flashinfer backend and you should also see good speedup.
After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get. I also suggestion trying the flashinfer backend and you should also see good speedup.
Thx Xiaoxuan, so do you think this optimization idea is qualified to merge to vllm. We treat this as a platform independent optimization(orthogonal w/ backend optimizations like flashinfer), which can benefit other device backends like CPU/XPU, and we see similar performance issue.
This change is to align the verify token function to the transformers speculative sampling algorithm, it always selectes the next sample token from the target model.
This change is to align the verify token function to the transformers speculative sampling algorithm, it always selectes the next sample token from the target model.
From reading the code, it seems it adjusts the distribution and resamples (line 4195 - 4203)?
After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get. I also suggestion trying the flashinfer backend and you should also see good speedup.
Thx Xiaoxuan, so do you think this optimization idea is qualified to merge to vllm. We treat this as a platform independent optimization(orthogonal w/ backend optimizations like flashinfer), which can benefit other device backends like CPU/XPU, and we see similar performance issue.
Could you double check the correctness here? If the optimization can pass the rejection sampling tests, yeah happy to review and get it in.
From reading the code, it seems it adjusts the distribution and resamples (line 4195 - 4203)?
Yes, but the point is the output will always contain the first unmatched token p_n_plus_1 (which is the next sample token from the target model).
Could you double check the correctness here? If the optimization can pass the rejection sampling tests, yeah happy to review and get it in.
The error comes from sampling. We can not guarantee the output will be all matched even if the target model is the same as the draft model because sampling will introduce random factors. So I disabled sampling by setting temperature=0 so we can make sure all tokens will be matched.
The error comes from sampling. We can not guarantee the output will be all matched even if the target model is the same as the draft model because sampling will introduce random factors. So I disabled sampling by setting
temperature=0so we can make sure all tokens will be matched.
This is the intention because you cannot force users to set temperature=0. That's why @LiuXiaoxuanPKU suggested we could bypass rejected sampling when temperature=0, but we cannot remove rejected sampling.
This is the intention because you cannot force users to set temperature=0. That's why @LiuXiaoxuanPKU suggested we could bypass rejected sampling when temperature=0, but we cannot remove rejected sampling.
We cannot expect the output length of speculative decoding to be a fixed number if sampling is applied, it could be in [1, num_speculative_tokens] in a single step. How about changing the test to verify the output token number in a reasonable interval instead of finishing after 2 steps?
OK, I removed temperature=0 but changed the max_tokens and num_speculative_tokens to avoid tests failed by sampling. Please let me know your opinion about this test.
I'm confused. It seems this PR removes rejection sampling. Then how do you do speculative decoding with temperature != 0?
I'm confused. It seems this PR removes rejection sampling. Then how do you do speculative decoding with temperature != 0?
I didn't remove rejection sampling, just removed recovered_token since it is not needed. The newly defined bonus_token covers the recovered_token case. Besides, I kept the temperature in default value in the test.
The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.
The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.
I see your point, the recovered_token is selected from a new distribution, but I just select it from target_token. It is more convenient and also makes sense.
I know adjusting the distribution of target_prob is from the original paper, but I didn't see any advantages compared to just selecting from target_token_ids, and it also introduces a customized multinomial function and some overheads.
Please let me know if you want me to revert it to keep the original recovered_token from the adjusted target probs, thanks.
The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.
I see your point, the
recovered_tokenis selected from a new distribution, but I just select it from target_token. It is more convenient and also makes sense.I know adjusting the distribution of
target_probis from the original paper, but I didn't see any advantages compared to just selecting fromtarget_token_ids, and it also introduces a customized multinomial function and some overheads.Please let me know if you want me to revert it to keep the original
recovered_tokenfrom the adjusted target probs, thanks.
I don't think change original paper algorithm is a good idea without data proving, and I don't think change behavior is the target of this PR. This PR's target is performance optimization. @jiqing-feng, pls only optimize performance for temperature == 0, and don't change the logic of others. You can submit another PR or issue if you wanna discuss it, keeping the PR only for one thing is better for cognition burden and make things move fast forward.
I have reverted unnecessary changes. Now, the rejection sampling exactly follows the paper and is the same as transformers integration.
The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.
Yes, you were right. I have fixed the recovered token by selecting it from the new distribution (torch.clamp(target_prob-draft_prob), min=0). Please take a review. Thx.
Hi @LiuXiaoxuanPKU . I checked the rejection sampler codes in detail and found there is no need to change it because you can get the correct recovered token ids. Only 1 thing: The _multinomial function just selects the tokens with the largest probability which is different with torch.multinomial
I am okay with the little difference btw vllm and the original paper because I cannot get speed-up on rejection sampling with this PR, but it could get significant speed-up on typical acceptance sampler. So I opened another PR to only change typical acceptance sampler, see #8562 . Please take a review, thx.