[V1][Spec Decode] Ngram Spec Decode
This PR tries to add ngram spec decode to V1. Design doc: here. Major changes:
- Since we only implement ngram spec decode, we did not add another scheduler for running the drafting method; we always check whether an ngram lookup is needed before calling the scheduler (a minimal sketch of the lookup follows this list).
- Add a new field `_spec_token_ids` in `Request` to track speculated tokens.
- Changes to `model_runner`:
  - 3.1 Change `_prepare_input` to also return the logits of speculated tokens.
  - 3.2 Change `_prepare_input` to add speculated tokens as input tokens.
  - 3.3 Change `execute_model` to generate multiple tokens per call. Concretely, it will add more than one token to `input_batch` and `req_state`.
- We only perform spec decode for requests in the running queue.
- We only support greedy decoding for now.
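For context, here is a minimal sketch of the n-gram (prompt-lookup) drafting idea referenced above. The function name and defaults (`propose_ngram_draft`, `n`, `k`) are illustrative, not this PR's implementation:

```python
from typing import List


def propose_ngram_draft(token_ids: List[int], n: int = 3, k: int = 5) -> List[int]:
    """Propose draft tokens by prompt lookup (illustrative sketch).

    Find the most recent earlier occurrence of the last `n` tokens and
    speculate that the `k` tokens that followed it will repeat. An empty
    list is returned when there is no match, which is one reason the
    scheduler cannot assume a fixed number of draft tokens per request.
    """
    if len(token_ids) <= n:
        return []
    pattern = token_ids[-n:]
    # Scan backwards so the most recent occurrence wins; skip the trailing
    # n-gram itself (it always matches).
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == pattern:
            return token_ids[start + n:start + n + k]
    return []
```

Because the proposer can return anywhere between 0 and k tokens, the scheduler needs to see the proposed tokens before allocating slots, which is why the lookup runs before the scheduler is called.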
What is missing
- [x] Change scheduling to only propose tokens for decoding requests.
- [x] Stop checking for spec decode, where multiple tokens are generated in a single step.
- [x] For the ngram lookup logic, currently I just append dummy tokens directly instead of performing the lookup. We can move v0's lookup logic here.
- [x] Check the correctness of this PR with chunked prefill. <-- We only perform spec decode in the decoding phase.
- [ ] More end-to-end tests & style fixes.
Tasks out of the scope of this PR
- Optimize the performance of ngram lookup.
- Support non-greedy decoding.
- Add other spec decode methods.
[Update] I will move the following two features into follow-up PRs:
- Guarantee the correctness of prefix caching + spec decode, because it involves changing the behavior of the kv cache manager (cc @comaniac).
- Change the scheduling policy to guarantee that at least one token is scheduled for each request. This is separated because it touches the scheduling code and needs more careful thought/testing.
Minor: There is a minimal example/test in `tests/v1/e2e/test_basic_specdecode.py`. You can check it for the current usage and verify correctness with `pytest -s tests/v1/e2e/test_basic_specdecode.py`.
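For readers who want to try it, a hedged usage sketch. It assumes the v0-style engine arguments (`speculative_model="[ngram]"`, `num_speculative_tokens`, `ngram_prompt_lookup_max`) carry over to V1; the test above is the authoritative reference for what this PR actually exposes:

```python
# Run with the environment variable VLLM_USE_V1=1 to select the V1 engine.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="[ngram]",   # assumed v0-style ngram selector
    num_speculative_tokens=5,      # assumed: k draft tokens per step
    ngram_prompt_lookup_max=4,     # assumed: max n-gram size for the lookup
)
# Greedy decoding only for now (temperature=0.0).
outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```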
Follow-up work:
- benchmark the flashinfer rejection sampler
- ngram kernel
- proposer stop checking
- KV cache based draft model
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add the `ready` label to the PR
- Enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LiuXiaoxuanPKU.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Surely late here, but why is a speculative decoding-aware scheduler needed? Wouldn't it be possible to just assume multi-token generation per-step as default?
> Surely late here, but why is a speculative decoding-aware scheduler needed? Wouldn't it be possible to just assume multi-token generation per-step as default?
Because the scheduler has to know how many kv-cache slots are needed for each request. In v0 we use lookahead slots, which always allocate k lookahead slots per request when spec decode is enabled. However, that is inefficient when we don't have k spec tokens for every request. This can happen in the following cases:
- N-gram won't propose any tokens if it fails to find a match.
- The draft model generates EOS.
- Insufficient kv-cache slots.
- In dynamic speculative decoding, we control "k" based on the current traffic.
So in this design for v1, we first get the spec tokens, and let the target model scheduler allocate the exact number of slots accordingly.
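A rough illustration of the two allocation policies, using hypothetical helper names rather than vLLM's actual scheduler API:

```python
def slots_needed_v0_style(num_running_requests: int, k: int) -> int:
    # v0 lookahead slots: always reserve k extra slots per running request,
    # even when the proposer produced fewer (or zero) draft tokens.
    return num_running_requests * (1 + k)


def slots_needed_v1_style(spec_token_ids_per_request: list) -> int:
    # v1: the proposer runs before scheduling, so the scheduler can allocate
    # exactly one slot for the regular next token plus one per draft token.
    return sum(1 + len(spec) for spec in spec_token_ids_per_request)
```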
I see, thanks a lot for elaborating @comaniac!
A problem with making a perfect PR is that it takes longer and doesn't let us split the work among contributors. I actually don't think the tech debt in v0 came from multiple PRs; it came from fast delivery with workaround designs. So we should really have a great design doc that considers everything and aligns everyone, and after that we just follow the doc when sending PRs.
@njhill Thanks for the detailed review! I do agree that this PR currently needs more iterations and should not break/degrade any case when spec decoding is unused.
> Tasks out of the scope of this PR: Optimize the performance of ngram lookup.

IMHO we should take a different approach to v0 and not merge this until it's optimized, given that the whole purpose of it is performance improvement.
I was actually the one who suggested keeping this optimization out of the current PR. Let me provide the context:
I was planning to optimize the KMP search algorithm (used in the N-gram proposer) using Numba. My profiling shows that `@numba.njit` significantly improves performance, reducing search time from 150 us to 6 us per request for a 4K context length. However, to achieve optimal performance with Numba, the inputs need to be NumPy arrays. This would require refactoring `all_token_ids` to use a NumPy-backed list. Given the complexity of this change, I decided it would be better to handle it in a separate PR rather than including it here.
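For illustration, the kind of Numba-jitted lookup being described might look like the sketch below; the function name and the simple backward scan (rather than full KMP) are assumptions, not the code planned for the follow-up PR:

```python
import numba
import numpy as np


@numba.njit(cache=True)
def ngram_lookup(token_ids: np.ndarray, n: int, k: int) -> np.ndarray:
    """Backward-scan prompt lookup over a NumPy token array (illustrative).

    Returns up to k draft tokens that followed the most recent earlier
    occurrence of the trailing n-gram, or an empty slice if there is none.
    """
    total = token_ids.shape[0]
    if total <= n:
        return token_ids[:0]
    for start in range(total - n - 1, -1, -1):
        match = True
        for j in range(n):
            if token_ids[start + j] != token_ids[total - n + j]:
                match = False
                break
        if match:
            end = min(start + n + k, total)
            return token_ids[start + n:end]
    return token_ids[:0]
```

The point of the `all_token_ids` refactor is that a jitted function like this can operate directly on a contiguous NumPy array, keeping the inner loop out of the Python interpreter.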
> I was planning to optimize the KMP search algorithm (used in the N-gram proposer) using Numba. My profiling shows that `@numba.njit` significantly improves performance, reducing search time from 150 us to 6 us per request for a 4K context length. However, to achieve optimal performance with Numba, the inputs need to be NumPy arrays. This would require refactoring `all_token_ids` to use a NumPy-backed list. Given the complexity of this change, I decided it would be better to handle it in a separate PR rather than including it here.
Thanks @WoosukKwon, this makes sense!
> A problem with making a perfect PR is that it takes longer and doesn't let us split the work among contributors. I actually don't think the tech debt in v0 came from multiple PRs; it came from fast delivery with workaround designs. So we should really have a great design doc that considers everything and aligns everyone, and after that we just follow the doc when sending PRs.
@comaniac yes I strongly agree that multiple/incremental PRs are good, but IMO each of these PRs should be streamlined/simplified before merging, or else what's added on top by others can compound the inefficiencies/complexity and make this harder to address later. I don't think that matters so much where the to-do optimization is known/isolated, like what @WoosukKwon is referring to above; I'm referring more to the general structure. It's also an option for multiple people to help with the same PR.
Hi folks, thanks for the detailed review!! I made the following changes to this PR:
- Stop checking: Changed the stop-checking logic as discussed with @WoosukKwon. I no longer crop the request; we only append a token_id to the output if it passes the stop check (a minimal sketch of this follows the list below). The `_check_stop()` logic is kept as before. Hopefully this change resolves many of the issues above.
- As requested by @WoosukKwon, I simplified the logic for counting the number of cached blocks: I use `cached_block_num` in the kv cache manager to record the number of cached blocks for each request. That field is only used by spec decode for now (cc @comaniac).
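To make the first bullet concrete, here is a minimal sketch of the per-token stop check (hypothetical names, not the PR's actual `_check_stop` signature): each token produced in a multi-token step is appended and checked in order, and anything after the stopping token is never appended, so there is no need to crop the request afterwards.

```python
def append_step_outputs(request, sampled_token_ids, check_stop) -> bool:
    """Append this step's tokens one by one, stopping at the first stop hit.

    `request` is assumed to expose `output_token_ids`; `check_stop` mirrors
    the per-request stop check. Returns True if the request finished
    during this step.
    """
    for token_id in sampled_token_ids:
        request.output_token_ids.append(token_id)
        if check_stop(request):
            # Any remaining (rejected or bonus) tokens are simply not
            # appended, which is why no post-hoc cropping is needed.
            return True
    return False
```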
Some changes I did not make:
- I still keep `spec_token_ids` in the `Request` object because I find it a bit hard to keep this field only in the scheduler.
- Log probs and random sampling are not supported yet. We plan to support both in the future, but I prefer to defer them to follow-up PRs.
Please take another round of review @WoosukKwon @njhill and let me know your thoughts, thanks!!
PS: There are still 1-2 minor formatting comments; will fix them ASAP.
When is this PR expected to be completed? We are eagerly waiting to use the N-gram feature. Keep it up!
Hi folks @WoosukKwon @njhill, I just updated this PR:
- vectorize the rejection sampler.
- vectorize the spec tokens' input preparation.
- modify the logic for updating generated token ids in the `input_batch` here.

I feel 3 is still a bit redundant; let me know your thoughts.
Some other topics:
- I will also change the default value for
spec_token_idsinSamplingMetadatato()as suggested by @njhill. - I think ngram can work with random sampling, but not sure about the performance here. The difference is that we sample the output tokens instead of using argmax. The proposing logic is the same. So we compare the ngram lookup results with sampled output. It might affect the token acceptance rate, but it should be runnable.
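To make the greedy acceptance rule concrete, here is a hedged sketch of the vectorized comparison in PyTorch (tensor shapes and names are illustrative, not the PR's rejection sampler API): accept the longest prefix of draft tokens that matches the target model's argmax, then emit the target's token at the first mismatch, or the bonus token if everything matched. With random sampling, the same comparison would be made against sampled target tokens instead of the argmax.

```python
import torch


def greedy_accept(draft_token_ids: torch.Tensor,
                  target_logits: torch.Tensor) -> torch.Tensor:
    """Vectorized greedy verification for a single request (illustrative).

    draft_token_ids: [k] proposed tokens.
    target_logits:   [k + 1, vocab] target logits at each speculated position,
                     plus one extra position for the bonus token.
    Returns the accepted tokens followed by the correction/bonus token.
    """
    target_token_ids = target_logits.argmax(dim=-1)        # [k + 1]
    matches = target_token_ids[:-1] == draft_token_ids     # [k]
    # Number of leading True values == length of the accepted prefix.
    num_accepted = int(torch.cumprod(matches.int(), dim=0).sum().item())
    # Either the bonus token (all accepted) or the target's token at the
    # first mismatch position.
    return target_token_ids[:num_accepted + 1]
```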
@LiuXiaoxuanPKU Could you please provide performance benchmarks?
- Main branch
- This PR without spec decoding
- This PR with spec decoding + low QPS
- This PR with spec decoding + high QPS
While I know that the Ngram proposer needs to be optimized, it'd be nice to check that 1 & 2 show the same perf, and we get a decent speedup on 3.
Some benchmark results:
- Model: meta-llama/Meta-Llama-3-8B-Instruct
- Hardware: 1x H100
- Number of requests: 500 for QPS 1 and 10; 1000 for QPS 20
- `export VLLM_USE_V1=1`
Median TTFT/TPOT
The median TTFT/TPOT looks OK when comparing the ngram branch without SD against the main branch, but it does slow down main a bit. My guess is that this is due to the change for delayed CPU <-> GPU synchronization.
P99 TTFT/TPOT
Detailed setting in doc.
@LiuXiaoxuanPKU Could you take a look at the failed tests?
I'll approve the PR once the tests are green!
@LiuXiaoxuanPKU Just wanted to double check. Is the PR ready for merge?
@WoosukKwon Hi, the vllm nightly wheel doesn't have the v1/spec_decode directory.
Traceback (most recent call last):
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/arbiter.py", line 608, in spawn_worker
worker.init_process()
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/uvicorn/workers.py", line 75, in init_process
super().init_process()
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/workers/base.py", line 135, in init_process
self.load_wsgi()
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/workers/base.py", line 147, in load_wsgi
self.wsgi = self.app.wsgi()
^^^^^^^^^^^^^^^
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/app/base.py", line 66, in wsgi
self.callable = self.load()
^^^^^^^^^^^
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/app/wsgiapp.py", line 57, in load
return self.load_wsgiapp()
^^^^^^^^^^^^^^^^^^^
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/app/wsgiapp.py", line 47, in load_wsgiapp
return util.import_app(self.app_uri)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/gunicorn/util.py", line 370, in import_app
mod = importlib.import_module(module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/mosh/.local/share/uv/python/cpython-3.12.7-linux-x86_64-gnu/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 995, in exec_module
File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
File "/data/lijinghui/uv_projects/LLM/chat_xiaoai.py", line 17, in <module>
from vllm import SamplingParams, AsyncEngineArgs
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/vllm/__init__.py", line 12, in <module>
from vllm.engine.async_llm_engine import AsyncLLMEngine
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 1196, in <module>
from vllm.v1.engine.async_llm import AsyncLLM
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/vllm/v1/engine/async_llm.py", line 25, in <module>
from vllm.v1.engine.core_client import EngineCoreClient
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 20, in <module>
from vllm.v1.engine.core import EngineCore, EngineCoreProc
File "/data/lijinghui/uv_projects/.venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 30, in <module>
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
ModuleNotFoundError: No module named 'vllm.v1.spec_decode'
@JaheimLee Thanks for reporting it! Fixed by #13359
Hello, I have already reviewed the code of this PR. May I ask whether spec decode (SD) in V1 is fully supported? I only saw the propose stage and didn't see the score and verify stages; I also didn't see the relevant workers.
It's just ngram for now
> It's just ngram for now
When will the remaining spec decode work for V1 be completed? Is there a timeline?
> Hello, I have already reviewed the code of this PR. May I ask whether spec decode (SD) in V1 is fully supported? I only saw the propose stage and didn't see the score and verify stages; I also didn't see the relevant workers.
The verify and score stages are there for ngram. You can already try/test/benchmark ngram spec decode with V1. We will add MTP support in the coming weeks.