Multi-Step + Chunked Prefill with Prefill Stepping
Add Chunked-Prefill support to Multi-Step logic.
With this PR, decode sequences can run in every step and are no longer interrupted by prefills. This reduces TPOT (time per output token).
Idea:
When Chunked-Prefill is enabled with Multi-Step, each scheduler iteration executes num_scheduler_steps model steps. In the examples below, let num_scheduler_steps be 8.
At scheduling time, a sequence can be one of the following:
- Chunked Prompt: A prompt sequence that does not require sampling in the next multi-step.
  - Example 1: If a prompt has 65 token ids, it can be scheduled with a token_chunk_size of 8. Step 1 processes token ids 0-7, step 2 processes token ids 8-15, and so on. At the end of 8 steps, 64 tokens have been processed.
  - Example 2: If a prompt has 64 token ids, the maximum token_chunk_size for multi-step is 7, so that the sequence stays in the do_sample = False state throughout the multi-step (a chunk size of 8 would consume the whole prompt in step 8, which would then need to sample). Here, step 1 processes token ids 0-6, step 2 processes token ids 7-13, and so on. At the end of 8 steps, 56 tokens have been processed.
  - Note: token_chunk_size is a per-sequence attribute and can differ between the scheduled sequences. If the sequences from the two examples are scheduled together, they can have token chunk sizes of {8, 7} respectively. The per-sequence token_chunk_size can be set based on heuristics.
- Single Step Prompt: A sequence that enters the multi-step as a prefill and exits as a decode. It is processed as a prefill in the 1st step and treated as a decode in the remaining steps.
  - Example: If a prompt has 3 token ids, it is scheduled with a token_chunk_size of 3. Step 1 treats this sequence as a prefill and processes all 3 tokens. Steps 2-8 treat it as a decode with an effective token_chunk_size of 1. At the end of 8 steps, 10 (3 + 7) tokens have been processed and 8 tokens have been generated (1 per step).
  - Note: Any prompt sequence can be scheduled as a Single Step Prompt. This allows some flexibility in maximizing the batch size.
- Decode: The usual decode sequences.
Every sequence type listed above can stay active throughout the multi-step and can be scheduled as required.
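The token accounting in these examples can be reproduced with a short simulation. This is a minimal sketch, not the PR's scheduler code: the function names and the string labels for the sequence kinds are hypothetical, and the (n - 1) // num_steps bound on the chunk size is derived from the two Chunked Prompt examples above (65 remaining tokens allow a chunk of 8, 64 allow only 7).

```python
NUM_SCHEDULER_STEPS = 8  # the value used in the examples above


def max_chunk_size(num_remaining_prompt_tokens: int,
                   num_steps: int = NUM_SCHEDULER_STEPS) -> int:
    """Largest token_chunk_size that keeps a Chunked Prompt at do_sample = False.

    num_steps * chunk must stay strictly below the remaining prompt length;
    otherwise the prompt would finish inside the multi-step and need sampling.
    """
    return max((num_remaining_prompt_tokens - 1) // num_steps, 0)


def run_multi_step(kind: str, num_prompt_tokens: int = 0,
                   token_chunk_size: int = 0,
                   num_steps: int = NUM_SCHEDULER_STEPS) -> tuple[int, int]:
    """Return (tokens processed, tokens generated) for one sequence."""
    if (kind == "chunked_prompt"
            and token_chunk_size > max_chunk_size(num_prompt_tokens, num_steps)):
        raise ValueError("chunk would finish the prompt mid multi-step")
    processed = generated = 0
    for step in range(num_steps):
        if kind == "chunked_prompt":
            processed += token_chunk_size  # prompt chunk, nothing sampled
        elif kind == "single_step_prompt":
            # Step 1 is a full prefill; the remaining steps behave like decodes.
            processed += num_prompt_tokens if step == 0 else 1
            generated += 1  # one token sampled per step
        elif kind == "decode":
            processed += 1
            generated += 1
        else:
            raise ValueError(f"unknown sequence kind: {kind}")
    return processed, generated


print(run_multi_step("chunked_prompt", 65, token_chunk_size=8))  # (64, 0)
print(run_multi_step("chunked_prompt", 64, token_chunk_size=7))  # (56, 0)
print(run_multi_step("single_step_prompt", 3))                  # (10, 8)
```

Asking for a chunk of 8 on the 64-token prompt raises, which mirrors why Example 2 is capped at 7.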
Sample logging output:
- Benchmark serving with num_prompts = 1000 and inf QPS
- Main (with --num-scheduler-steps 8)
Scheduler step 0: waiting 0, running 1
- 1st step : total input toks: 13 | num prompt toks 13 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 1: P 1 + Ps 0 + D 0
Scheduler step 1: waiting 965, running 35
- 1st step : total input toks: 7452 | num prompt toks 7452 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 34: P 34 + Ps 0 + D 0
Scheduler step 2: waiting 932, running 68
- 1st step : total input toks: 8070 | num prompt toks 8070 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 33: P 33 + Ps 0 + D 0
Scheduler step 3: waiting 896, running 93
- 1st step : total input toks: 7947 | num prompt toks 7947 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 36: P 36 + Ps 0 + D 0
Scheduler step 4: waiting 857, running 122
- 1st step : total input toks: 8006 | num prompt toks 8006 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 39: P 39 + Ps 0 + D 0
Scheduler step 5: waiting 812, running 153
- 1st step : total input toks: 8114 | num prompt toks 8114 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 45: P 45 + Ps 0 + D 0
Scheduler step 6: waiting 777, running 177
- 1st step : total input toks: 7476 | num prompt toks 7476 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 35: P 35 + Ps 0 + D 0
Scheduler step 7: waiting 739, running 200
- 1st step : total input toks: 8170 | num prompt toks 8170 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 38: P 38 + Ps 0 + D 0
Scheduler step 8: waiting 694, running 231
- 1st step : total input toks: 8162 | num prompt toks 8162 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 45: P 45 + Ps 0 + D 0
Scheduler step 9: waiting 654, running 256
- 1st step : total input toks: 6559 | num prompt toks 6559 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 40: P 40 + Ps 0 + D 0
Scheduler step 10: waiting 642, running 256
- 1st step : total input toks: 2386 | num prompt toks 2386 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 12: P 12 + Ps 0 + D 0
Scheduler step 11: waiting 629, running 256
- 1st step : total input toks: 2056 | num prompt toks 2056 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 13: P 13 + Ps 0 + D 0
Scheduler step 12: waiting 628, running 256
- 1st step : total input toks: 289 | num prompt toks 289 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 1: P 1 + Ps 0 + D 0
Scheduler step 13: waiting 625, running 256
- 1st step : total input toks: 619 | num prompt toks 619 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 3: P 3 + Ps 0 + D 0
Scheduler step 14: waiting 625, running 256
- 1st step : total input toks: 256 | num prompt toks 0 | num prompt to decode toks 0 | num decode toks 256
- Seq Types 256: P 0 + Ps 0 + D 256
Scheduler step 22: waiting 593, running 256
- 1st step : total input toks: 5725 | num prompt toks 5725 | num prompt to decode toks 0 | num decode toks 0
- Seq Types 32: P 32 + Ps 0 + D 0
....
- This PR (with --num-scheduler-steps 8 --enable-chunked-prefill)
Scheduler step 0: waiting 853, running 36, finished 0, swapped 0
- 1st step : total input toks: 7654 | num prompt toks 0 | num prompt to decode toks 7654 | num decode toks 0
- Rest steps : total input toks: 36 | num prompt toks 0 | num prompt to decode toks 36 | num decode toks 0
- Seq Types 36: P 0 + Ps 36 + D 0
Scheduler step 8: waiting 932, running 53, finished 15, swapped 0
- 1st step : total input toks: 7963 | num prompt toks 0 | num prompt to decode toks 7942 | num decode toks 21
- Rest steps : total input toks: 53 | num prompt toks 0 | num prompt to decode toks 32 | num decode toks 21
- Seq Types 53: P 0 + Ps 32 + D 21
Scheduler step 16: waiting 896, running 75, finished 29, swapped 0
- 1st step : total input toks: 8052 | num prompt toks 0 | num prompt to decode toks 8013 | num decode toks 39
- Rest steps : total input toks: 75 | num prompt toks 0 | num prompt to decode toks 36 | num decode toks 39
- Seq Types 75: P 0 + Ps 36 + D 39
Scheduler step 24: waiting 856, running 95, finished 49, swapped 0
- 1st step : total input toks: 7877 | num prompt toks 0 | num prompt to decode toks 7822 | num decode toks 55
- Rest steps : total input toks: 95 | num prompt toks 0 | num prompt to decode toks 40 | num decode toks 55
- Seq Types 95: P 0 + Ps 40 + D 55
Scheduler step 32: waiting 811, running 127, finished 62, swapped 0
- 1st step : total input toks: 8188 | num prompt toks 0 | num prompt to decode toks 8106 | num decode toks 82
- Rest steps : total input toks: 127 | num prompt toks 0 | num prompt to decode toks 45 | num decode toks 82
- Seq Types 127: P 0 + Ps 45 + D 82
....
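Reading the two logs side by side: on main, each scheduler iteration is either all prefills or all decodes (e.g. step 14 runs D 256 with no prompts), while with this PR every iteration after the first mixes Single Step Prompts with decodes. The small checker below makes the bookkeeping in these lines explicit. It is a hypothetical helper that assumes the debug-log format exactly as printed above, and it reads P as prompt (Chunked Prompt) sequences, Ps as Single Step Prompts ("prompt to decode"), and D as decodes.

```python
import re

# Patterns for the debug log lines printed above (format assumed verbatim).
TOKS_RE = re.compile(
    r"(?P<phase>1st step|Rest steps) : total input toks: (?P<total>\d+) \| "
    r"num prompt toks (?P<prompt>\d+) \| "
    r"num prompt to decode toks (?P<p2d>\d+) \| num decode toks (?P<decode>\d+)")
SEQ_RE = re.compile(
    r"Seq Types (?P<n>\d+): P (?P<p>\d+) \+ Ps (?P<ps>\d+) \+ D (?P<d>\d+)")


def _expect(cond: bool, msg: str) -> None:
    if not cond:
        raise ValueError(msg)


def check_block(lines: list[str]) -> tuple[dict, dict]:
    """Parse one scheduler-iteration block and check its token accounting."""
    phases: dict[str, dict[str, int]] = {}
    seqs: dict[str, int] = {}
    for line in lines:
        if m := TOKS_RE.search(line):
            c = {k: int(m[k]) for k in ("total", "prompt", "p2d", "decode")}
            _expect(c["total"] == c["prompt"] + c["p2d"] + c["decode"],
                    f"token parts do not add up: {line}")
            phases[m["phase"]] = c
        elif m := SEQ_RE.search(line):
            seqs = {k: int(m[k]) for k in ("n", "p", "ps", "d")}
            _expect(seqs["n"] == seqs["p"] + seqs["ps"] + seqs["d"],
                    f"sequence counts do not add up: {line}")
    rest = phases.get("Rest steps")
    if rest and seqs:
        # After step 1, every Single Step Prompt and Decode adds exactly 1 token.
        _expect(rest["p2d"] == seqs["ps"] and rest["decode"] == seqs["d"],
                "rest-step tokens should equal the Ps and D counts")
    return phases, seqs


block = [
    "- 1st step : total input toks: 7963 | num prompt toks 0 | "
    "num prompt to decode toks 7942 | num decode toks 21",
    "- Rest steps : total input toks: 53 | num prompt toks 0 | "
    "num prompt to decode toks 32 | num decode toks 21",
    "- Seq Types 53: P 0 + Ps 32 + D 21",
]
phases, seqs = check_block(block)
```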
Implementation Details:
Handling Single Step Prompts:
- The scheduler output orders the sequences as Chunked Prompts first, Single Step Prompts second and Decodes third. This lets us merge the Single Step Prompts into the Decodes after the 1st step with some ModelInput and AttentionMetadata updates.
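A minimal sketch of why that ordering helps, under the assumption that the model input separates prefills and decodes by a single boundary index: once the batch is sorted Chunked Prompts, then Single Step Prompts, then Decodes, turning the Single Step Prompts into decodes after step 1 only moves that boundary. SeqKind, Seq, order_batch and split_regions are hypothetical names, not the PR's ModelInput or AttentionMetadata code.

```python
from dataclasses import dataclass
from enum import IntEnum


class SeqKind(IntEnum):
    # Order in which the scheduler output lists sequences.
    CHUNKED_PROMPT = 0
    SINGLE_STEP_PROMPT = 1
    DECODE = 2


@dataclass
class Seq:
    seq_id: int
    kind: SeqKind


def order_batch(seqs: list[Seq]) -> list[Seq]:
    # Stable sort: Chunked Prompts, then Single Step Prompts, then Decodes.
    return sorted(seqs, key=lambda s: s.kind)


def split_regions(batch: list[Seq], step: int) -> tuple[list[int], list[int]]:
    """Return (prefill seq_ids, decode seq_ids) for a 0-based step."""
    num_chunked = sum(s.kind == SeqKind.CHUNKED_PROMPT for s in batch)
    num_single = sum(s.kind == SeqKind.SINGLE_STEP_PROMPT for s in batch)
    # Step 0: Single Step Prompts are still prefills.
    # Later steps: they join the decodes, which only moves the boundary.
    boundary = num_chunked + num_single if step == 0 else num_chunked
    ids = [s.seq_id for s in batch]
    return ids[:boundary], ids[boundary:]


batch = order_batch([Seq(0, SeqKind.DECODE), Seq(1, SeqKind.SINGLE_STEP_PROMPT),
                     Seq(2, SeqKind.CHUNKED_PROMPT), Seq(3, SeqKind.DECODE)])
print(split_regions(batch, 0))  # ([2, 1], [0, 3])
print(split_regions(batch, 1))  # ([2], [1, 0, 3])
```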
Limitations:
- The sampling requirement of each sequence must remain constant throughout the multi-step, i.e., every sequence must be in either the do_sample = False or the do_sample = True state for all steps and cannot switch in between. Switching a sequence's sampling requirement arbitrarily is non-trivial, as it requires updating the token indices in the SamplingMetadata. However, it is not insurmountable and could be done in a follow-up PR.
  - Implication: when Single Step Prompts make up most of the batch in a scheduler iteration, only the 1st step runs at high efficiency. Efficiency drops in steps 2-8, as the Single Step Prompts turn into decodes.
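To see the size of that drop, the batched token count per step can be computed from the sequence mix. This is a minimal sketch of the accounting described above (the function is hypothetical, not from the PR). The example reuses the "Scheduler step 8" counts from the PR log (32 Single Step Prompts totalling 7942 tokens, plus 21 decodes); the individual prompt lengths are made up and only preserve that total and count.

```python
def per_step_tokens(chunk_sizes: list[int], single_step_prompt_lens: list[int],
                    num_decodes: int, num_steps: int = 8) -> list[int]:
    """Batched input tokens in each step of one multi-step iteration.

    Chunked Prompts add their chunk every step; Single Step Prompts add their
    whole prompt in step 1 and 1 token afterwards; Decodes add 1 token.
    """
    chunked = sum(chunk_sizes)
    first = chunked + sum(single_step_prompt_lens) + num_decodes
    rest = chunked + len(single_step_prompt_lens) + num_decodes
    return [first] + [rest] * (num_steps - 1)


# Hypothetical split of the logged 7942 prompt tokens over 32 prompts.
steps = per_step_tokens([], [254] + [248] * 31, num_decodes=21)
print(steps)  # [7963, 53, 53, 53, 53, 53, 53, 53]
```

The result matches the logged "1st step" (7963) and "Rest steps" (53) totals: the 1st step is a large, compute-bound batch, while steps 2-8 are small decode batches.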
Scheduling Logic:
- The sequence type (Chunked Prompt / Single Step Prompt) and token_chunk_size are schedule-time variables that can be tuned to maximize GPU efficiency.
- At the time of writing (September 11, 2024), we schedule aggressively, turning every prompt sequence into a Single Step Prompt if the budget allows. This seems to give the best performance.
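A hypothetical sketch of such an aggressive policy, not the PR's actual scheduler: each waiting prompt becomes a Single Step Prompt if the whole prompt fits in the remaining token budget, otherwise a Chunked Prompt with the largest chunk that both fits and keeps do_sample = False. Only the 1st step is checked against the budget, on the assumption that it is always the peak step (Single Step Prompts cost their full length only in step 1, while chunks and decodes cost the same every step).

```python
def schedule_prompts(waiting_prompt_lens: list[int], num_running_decodes: int,
                     max_num_batched_tokens: int, num_steps: int = 8):
    """Return one (kind, token_chunk_size) per waiting prompt, or None if it
    stays waiting. Hypothetical greedy policy for illustration only."""
    budget = max_num_batched_tokens - num_running_decodes  # 1 token per decode
    plan = []
    for n in waiting_prompt_lens:
        if n <= budget:
            plan.append(("single_step_prompt", n))  # whole prompt in step 1
            budget -= n
            continue
        # Largest chunk that fits and leaves the prompt unfinished after
        # num_steps steps, so no sampling is needed mid multi-step.
        chunk = min((n - 1) // num_steps, budget)
        if chunk > 0:
            plan.append(("chunked_prompt", chunk))
            budget -= chunk
        else:
            plan.append(None)
    return plan


print(schedule_prompts([100, 5000, 4000, 10], num_running_decodes=92,
                       max_num_batched_tokens=8192))
```

The 4000-token prompt no longer fits after the first two prompts, so it is chunked at 499 tokens per step; the final 10-token prompt still fits as a Single Step Prompt.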
Benchmarks:
Machine: 1x H100. Config: PP = 1, TP = 1
Benchmark Server Command (main):
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --port 9000 --swap-space 16 --disable-log-requests --use-v2-block-manager --tensor-parallel-size 1 --worker-use-ray --pipeline-parallel-size 1 --gpu-memory-utilization 0.90 --num-scheduler-steps 8 --max-num-batched-tokens 8192
Benchmark Client Command (main):
python3 benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-8B --dataset-name sharegpt --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --port 9000 --num-prompts 1000
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 23.47
Total input tokens: 215196
Total generated tokens: 128905
Request throughput (req/s): 42.61
Output token throughput (tok/s): 5492.30
Total Token throughput (tok/s): 14661.23
---------------Time to First Token----------------
Mean TTFT (ms): 4598.22
Median TTFT (ms): 3558.79
P99 TTFT (ms): 15457.14
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 43.27
Median TPOT (ms): 32.65
P99 TPOT (ms): 311.94
---------------Inter-token Latency----------------
Mean ITL (ms): 224.63
Median ITL (ms): 221.48
P99 ITL (ms): 506.36
==================================================
Benchmark Server Command (PR, identical to main except for --enable-chunked-prefill):
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --port 9000 --swap-space 16 --disable-log-requests --use-v2-block-manager --tensor-parallel-size 1 --worker-use-ray --pipeline-parallel-size 1 --gpu-memory-utilization 0.90 --num-scheduler-steps 8 --enable-chunked-prefill --max-num-batched-tokens 8192
Benchmark Client Command (PR):
python3 benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-8B --dataset-name sharegpt --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --port 9000 --num-prompts 1000
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 22.92
Total input tokens: 215196
Total generated tokens: 126839
Request throughput (req/s): 43.62
Output token throughput (tok/s): 5533.22
Total Token throughput (tok/s): 14920.93
---------------Time to First Token----------------
Mean TTFT (ms): 4723.33
Median TTFT (ms): 3900.02
P99 TTFT (ms): 15082.37
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 24.14
Median TPOT (ms): 27.18
P99 TPOT (ms): 38.62
---------------Inter-token Latency----------------
Mean ITL (ms): 207.18
Median ITL (ms): 228.80
P99 ITL (ms): 560.82
==================================================
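The relative changes between the two runs can be computed directly from the tables above. The values below are copied from the two results; each configuration was run once, so the small TTFT and throughput deltas may be within run-to-run noise.

```python
# Numbers copied from the two benchmark results above: (main, this PR).
RESULTS = {
    "Mean TPOT (ms)": (43.27, 24.14),
    "Median TPOT (ms)": (32.65, 27.18),
    "P99 TPOT (ms)": (311.94, 38.62),
    "Mean TTFT (ms)": (4598.22, 4723.33),
    "Total Token throughput (tok/s)": (14661.23, 14920.93),
}


def pct_change(before: float, after: float) -> float:
    """Relative change from main to the PR, in percent."""
    return (after - before) / before * 100.0


for name, (main, pr) in RESULTS.items():
    print(f"{name}: {main} -> {pr} ({pct_change(main, pr):+.1f}%)")
```

This gives roughly -44% mean TPOT and -88% P99 TPOT, against +2.7% mean TTFT and +1.8% total token throughput.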
TODOs:
- Fix output metrics logging
- Support PP and TP (at the moment tested only for single GPU)
- Support LLMEngine class (at the moment only supports AsyncLLMEngine)
- Cleanup / Refactor
- More Benchmarking with TP, PP, Scheduling Policies
PR Checklist
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:
- [Bugfix] for bug fixes.
- [CI/Build] for build or continuous integration improvements.
- [Doc] for documentation fixes and improvements.
- [Model] for adding a new model or improving an existing model. The model name should appear in the title.
- [Frontend] for changes to the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
- [Kernel] for changes affecting CUDA kernels or other compute kernels.
- [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
- [Hardware][Vendor] for hardware-specific changes. The vendor name should appear in the prefix (e.g., [Hardware][AMD]).
- [Misc] for PRs that do not fit the above categories. Please use this sparingly.
Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR needs to meet the following code quality standards:
- We adhere to the Google Python style guide and the Google C++ style guide.
- Pass all linter checks. Please use format.sh to format your code.
- The code needs to be well-documented to ensure future contributors can easily understand the code.
- Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
- Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.
Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.
What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
- After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
- After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
- After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
- Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add the ready label to the PR
- Enable auto-merge.
🚀
Just FYI @varun-sundar-rabindranath: I went ahead and built an image from this branch so I could load test it. While I am able to get SOTA QPS for my setup (very long inputs, ~4k tokens), the server quickly crashes with this message:
ERROR 09-17 12:32:53 async_llm_engine.py:58] Engine background task failed
ERROR 09-17 12:32:53 async_llm_engine.py:58] Traceback (most recent call last):
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 48, in _log_task_completion
ERROR 09-17 12:32:53 async_llm_engine.py:58] return_value = task.result()
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 736, in run_engine_loop
ERROR 09-17 12:32:53 async_llm_engine.py:58] result = task.result()
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 676, in engine_step
ERROR 09-17 12:32:53 async_llm_engine.py:58] request_outputs = await self.engine.step_async(virtual_engine)
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 340, in step_async
ERROR 09-17 12:32:53 async_llm_engine.py:58] outputs = await self.model_executor.execute_model_async(
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/gpu_executor.py", line 185, in execute_model_async
ERROR 09-17 12:32:53 async_llm_engine.py:58] output = await make_async(self.driver_worker.execute_model
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
ERROR 09-17 12:32:53 async_llm_engine.py:58] result = self.fn(*self.args, **self.kwargs)
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 327, in execute_model
ERROR 09-17 12:32:53 async_llm_engine.py:58] output = self.model_runner.execute_model(
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 09-17 12:32:53 async_llm_engine.py:58] return func(*args, **kwargs)
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/multi_step_model_runner.py", line 464, in execute_model
ERROR 09-17 12:32:53 async_llm_engine.py:58] model_input = self._advance_step(
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/multi_step_model_runner.py", line 586, in _advance_step
ERROR 09-17 12:32:53 async_llm_engine.py:58] attn_metadata.advance_step(
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/flash_attn.py", line 402, in advance_step
ERROR 09-17 12:32:53 async_llm_engine.py:58] ops.advance_step_flashattn(num_seqs=num_seqs,
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/_custom_ops.py", line 32, in wrapper
ERROR 09-17 12:32:53 async_llm_engine.py:58] return fn(*args, **kwargs)
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/vllm/_custom_ops.py", line 198, in advance_step_flashattn
ERROR 09-17 12:32:53 async_llm_engine.py:58] return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1061, in __call__
ERROR 09-17 12:32:53 async_llm_engine.py:58] return self_._op(*args, **(kwargs or {}))
ERROR 09-17 12:32:53 async_llm_engine.py:58] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 09-17 12:32:53 async_llm_engine.py:58] RuntimeError: tensor: name = sampled_token_ids, shape = [55, 1] is_cont = 1, type = long int is not as expected: shape = [56, 1], type = Long
Exception in callback functools.partial(<function _log_task_completion at 0x7a5cb81df060>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7a5cac0846b0>>)
handle: <Handle functools.partial(<function _log_task_completion at 0x7a5cb81df060>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7a5cac0846b0>>)>
Hoping this is useful for your development! Can't wait for this to land in a stable release.
Hey @sam-h-bean, thanks for testing this out! Can you please share the commands you used to test? It'd help get a repro quickly.
@varun-sundar-rabindranath the script was a garden variety locust test
locust -f vllm_server.py --headless --master --expect-workers=1 -r 2 -u 200 -t 3m --host http://vllm-sft-service.vllm-sft:8000
the invocation is just a post against /v1/chat/completions
and the k8s config is also pretty standard
containers:
  - name: vllm-sft-container
    image: {{ .Values.custom_vllm_image }}
    args:
      - "--model"
      - "{{ .model }}"
      - "--served-model-name"
      - "sft-llama"
      - "--disable-log-requests"
      - "--allow-credentials"
      - "--enable-prefix-caching"
      - "--enable-chunked-prefill"
      - "--max-num-batched-tokens"
      - "32768"
      - "--num-scheduler-steps"
      - "10"
      - "--gpu-memory-utilization"
      - "0.95"
      - "--tensor-parallel-size"
      - "{{ index .resources.limits "nvidia.com/gpu" }}"
      {{- if .extraArgs }}
      {{- range .extraArgs }}
      - "{{ . }}"
      {{- end }}
      {{- end }}
I will note that this doesn't show up until we get close to maximum load...
Sorry I can't share more. Hopefully you can still glean some useful debug info from this
Thanks for sharing @sam-h-bean 👍 I'll check it out !
[edit] I noticed you use --enable-prefix-caching with --enable-chunked-prefill. I haven't tested them together, as the PR only adds support for Chunked-Prefill with Multi-Step (--num-scheduler-steps). That could be a follow-up once this lands.
Also, I recently put in some performance bug fixes for --tensor-parallel-size > 1.
Can you please try without --enable-prefix-caching to see if that solves the issue? Thanks.
this did get the error to go away!
disabling prefix caching fixes the issue!
@varun-sundar-rabindranath I am running into other issues with a similar setup
INFO 09-18 11:41:30 server.py:228] vLLM ZMQ RPC Server was interrupted.
Future exception was never retrieved
future: <Future finished exception=IndexError('list index out of range')>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 115, in generate
async for request_output in results_generator:
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 862, in generate
async for output in await self.add_request(
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 106, in generator
raise result
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 115, in generate
async for request_output in results_generator:
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 862, in generate
async for output in await self.add_request(
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 106, in generator
raise result
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 115, in generate
async for request_output in results_generator:
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 862, in generate
async for output in await self.add_request(
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 106, in generator
raise result
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 48, in _log_task_completion
return_value = task.result()
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 736, in run_engine_loop
result = task.result()
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 676, in engine_step
request_outputs = await self.engine.step_async(virtual_engine)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 370, in step_async
step_num=self._current_step(seq_group_metadata_list) - 1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 1356, in _current_step
current_step = seq_group_metadata_list[0].state.current_step
~~~~~~~~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range
with this k8s setup
- "--model"
- "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
- "--served-model-name"
- "cai-llama"
- "--disable-log-requests"
- "--allow-credentials"
- "--enable-chunked-prefill"
- "--max-model-len"
- "8000"
- "--num-scheduler-steps"
- "10"
- "--quantization"
- "compressed-tensors"
- "--tensor-parallel-size"
- "{{ index .resources.limits "nvidia.com/gpu" }}"
curious if it is the combination of this experimental config and fp8 quantization
@sam-h-bean - I have been updating this PR with fixes, and I fixed this particular issue this morning. Can you pull the PR again and try? Sorry about the trouble, and thanks for testing 🙌
I might suggest throwing an error at startup time if someone tries enabling prefix caching + chunked prefill + multi-step (--num-scheduler-steps) in this case.
Yup. Added in this commit.
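The commit itself is not reproduced here. As an illustration only, a fail-fast check of that kind might look like the sketch below; EngineFlags and validate are hypothetical names mirroring the three CLI flags discussed in this thread, not vLLM's actual config classes or the code in that commit.

```python
from dataclasses import dataclass


@dataclass
class EngineFlags:
    # Hypothetical container mirroring the CLI flags discussed above.
    enable_prefix_caching: bool = False
    enable_chunked_prefill: bool = False
    num_scheduler_steps: int = 1


def validate(flags: EngineFlags) -> None:
    """Fail fast on the combination reported to crash in this thread."""
    multi_step = flags.num_scheduler_steps > 1
    if multi_step and flags.enable_chunked_prefill and flags.enable_prefix_caching:
        raise ValueError(
            "--enable-prefix-caching is not supported together with "
            "--enable-chunked-prefill and --num-scheduler-steps > 1")


validate(EngineFlags(enable_chunked_prefill=True, num_scheduler_steps=10))  # OK
```

Raising at startup turns the mid-load crash reported above into an immediate, descriptive error.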
Seeing some new interesting behavior after pulling your latest changes and rerunning the load test: requests just stack up in Pending after a few get through. I get 3 completions, then it just hangs.
INFO 09-18 16:10:53 metrics.py:351] Avg prompt throughput: 67.0 tokens/s, Avg generation throughput: 0.4 tokens/s, Running: 2 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%.
INFO: 10.0.0.99:58210 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO: 10.0.0.99:58222 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO: 10.0.0.99:58224 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 09-18 16:10:58 metrics.py:351] Avg prompt throughput: 70.6 tokens/s, Avg generation throughput: 5.6 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 12 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:45782 - "GET /health HTTP/1.1" 200 OK
INFO: 10.65.89.48:37148 - "GET /metrics HTTP/1.1" 200 OK
INFO: 10.64.145.165:48936 - "GET /metrics HTTP/1.1" 200 OK
INFO 09-18 16:11:03 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 22 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:08 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 32 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:37464 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:11:13 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 42 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:18 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 52 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:39902 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:11:23 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 62 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:28 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 72 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:37332 - "GET /health HTTP/1.1" 200 OK
INFO: 10.65.89.48:40538 - "GET /metrics HTTP/1.1" 200 OK
INFO: 10.64.145.165:40110 - "GET /metrics HTTP/1.1" 200 OK
INFO 09-18 16:11:33 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 82 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:38 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 92 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:33236 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:11:43 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 102 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:48 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 112 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:53916 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:11:53 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 122 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:11:58 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 132 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:60554 - "GET /health HTTP/1.1" 200 OK
INFO: 10.65.89.48:43442 - "GET /metrics HTTP/1.1" 200 OK
INFO: 10.64.145.165:55592 - "GET /metrics HTTP/1.1" 200 OK
INFO 09-18 16:12:03 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 142 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:12:08 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 152 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:47454 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:12:13 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 162 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:12:18 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 172 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:48256 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:12:23 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 182 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:12:28 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 192 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:44174 - "GET /health HTTP/1.1" 200 OK
INFO: 10.65.89.48:54242 - "GET /metrics HTTP/1.1" 200 OK
INFO: 10.64.145.165:58168 - "GET /metrics HTTP/1.1" 200 OK
INFO 09-18 16:12:33 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 200 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:12:38 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 200 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO: 10.64.214.1:47416 - "GET /health HTTP/1.1" 200 OK
INFO 09-18 16:12:43 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 200 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-18 16:12:48 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 200 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
I pulled a trace during the hanging load test in case it helps: VLLM_TRACE_FUNCTION2.log
Thanks for sharing the trace @sam-h-bean, I'll take a look.
Also, I pushed some changes based on what I thought was likely happening: when an input prompt's length is greater than the user-set --max-num-batched-tokens, the PR stalls because it thinks it has scheduled enough input tokens, but in reality it could never schedule that sequence. One way to get around this is to increase --max-num-batched-tokens. However, I pushed a change that ignores such sequences with a message to the user and continues with the rest of the sequences. Can you please try it out?
The correct way to handle this is to process the prompt in multiple chunks. I am working on it now.
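One plausible way such a stall arises can be sketched with a first-come-first-served prefill loop under a token budget. This is a hypothetical illustration, not the PR's scheduler: `schedule_prefills` and the `(request_id, prompt_len)` queue entries are assumptions. If a prompt longer than the entire budget sits at the head of the queue and is only ever deferred, every iteration schedules nothing, which would match the logs above (Running: 0, Pending growing). The sketch applies the workaround described in the comment (ignore such prompts and continue); the proper fix, chunking the prompt, is not shown.

```python
from collections import deque
from typing import Deque, List, Tuple


def schedule_prefills(
    waiting: Deque[Tuple[str, int]], max_num_batched_tokens: int
) -> Tuple[List[str], List[str]]:
    """Hypothetical FCFS prefill scheduler; returns (scheduled, ignored) request ids."""
    scheduled: List[str] = []
    ignored: List[str] = []
    budget = max_num_batched_tokens
    while waiting:
        req_id, prompt_len = waiting[0]
        if prompt_len > max_num_batched_tokens:
            # Can never fit in any iteration: drop it instead of blocking the queue.
            waiting.popleft()
            ignored.append(req_id)
            continue
        if prompt_len > budget:
            # Fits in a later iteration; stop now to preserve FCFS order.
            break
        waiting.popleft()
        budget -= prompt_len
        scheduled.append(req_id)
    return scheduled, ignored


queue = deque([("a", 100), ("b", 5000), ("c", 200)])
print(schedule_prefills(queue, max_num_batched_tokens=2048))  # (['a', 'c'], ['b'])
```

Without the first check, request "b" would hit the `break` on every call and nothing behind it would ever be scheduled.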
I did indeed see this problem go away when I set --max-num-batched-tokens, so you are probably right!
@SolitaryThinker @alexm-neuralmagic PTAL
I'm a bit confused about how chunked prefill works as a single step. The PR description mentions that a prompt's token_chunk_size is scheduled as one step. If I understand correctly, chunked prefill means a prompt is split into multiple chunks of token_chunk_size tokens; how do they get scheduled by the current multi-step scheduler? Can someone help explain?
So this is not really enabling "chunked" prefill with multi-step. This PR only enables multi-step to go through the chunked prefill code path. Specifically, if you enable chunked prefill and multi-step with this PR:
- Prefill and decode requests can be scheduled in the same batch; otherwise, they have to be in separate batches.
- When there are decode requests in the running queue, the scheduler prioritizes them and uses the remaining token budget to schedule new prefill requests. If chunked prefill is disabled, however, new prefill requests are always prioritized.
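The two policies in the list above can be contrasted in a small sketch. It is a simplified, hypothetical model rather than vLLM's scheduler: `build_batch` and `take_prefills` are made-up names, each decode is assumed to cost one token of budget per step, prompts are admitted whole (consistent with the note that this PR does not really chunk prompts), and KV-cache capacity and sequence-count limits are ignored.

```python
from typing import List, Tuple


def take_prefills(waiting: List[Tuple[str, int]], budget: int) -> List[str]:
    """Admit whole prompts in FCFS order until the token budget runs out."""
    taken: List[str] = []
    for req_id, prompt_len in waiting:
        if prompt_len > budget:
            break
        taken.append(req_id)
        budget -= prompt_len
    return taken


def build_batch(
    running_decodes: List[str],
    waiting_prefills: List[Tuple[str, int]],
    token_budget: int,
    chunked_prefill: bool,
) -> Tuple[List[str], List[str]]:
    """Return (decode_ids, prefill_ids) for the next scheduler iteration."""
    if chunked_prefill:
        # Decodes go first; each running sequence costs one token per step.
        decodes = running_decodes[:token_budget]
        # The remaining budget is spent on new prefills in the same batch.
        prefills = take_prefills(waiting_prefills, token_budget - len(decodes))
        return decodes, prefills
    # Without chunked prefill, new prefills win and run in a batch of their own.
    prefills = take_prefills(waiting_prefills, token_budget)
    if prefills:
        return [], prefills
    return running_decodes[:token_budget], []


decodes = ["d1", "d2", "d3"]
prefills = [("p1", 100), ("p2", 60)]
print(build_batch(decodes, prefills, 128, chunked_prefill=True))   # (['d1', 'd2', 'd3'], ['p1'])
print(build_batch(decodes, prefills, 128, chunked_prefill=False))  # ([], ['p1'])
```

With chunked prefill the running decodes keep making progress while "p1" is prefilled alongside them; without it, the decodes wait for a later iteration.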
Yea okay that makes more sense now. Thanks!
QQ: what's the definition of num_computed_tokens? For example, given a prompt [1,2,3,4,5], after the prefill phase (after process_output) one new token is generated, so we get [1,2,3,4,5,6].
Before this PR: num_computed_tokens=5
After this PR: num_computed_tokens=6
I guess we should set num_computed_tokens to 5, as the last token is sampled, not computed. Please let me know if my understanding is wrong.
@varun-sundar-rabindranath @comaniac
Hi @LiuXiaoxuanPKU, I believe you are right! I looked up the definition of num_computed_tokens:
def get_num_computed_tokens(self) -> int:
    """Return the number of prefill tokens that are already computed."""
    return self._num_computed_tokens
I had misunderstood it as including both prefill and decode tokens. I'll put up a PR.
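A toy model makes the distinction in the question concrete. `ToySequence` and `after_step` are hypothetical, not vLLM classes, and the sketch assumes a token counts as computed once it has been fed through a forward pass. A sampled token is therefore appended immediately but only counted one step later, when it is processed as input. Under that assumption, the prompt [1,2,3,4,5] gives num_computed_tokens=5 after prefill, the "before this PR" value; counting the sampled token at sampling time would count it before it has been computed.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToySequence:
    """Hypothetical stand-in for per-sequence token bookkeeping."""

    token_ids: List[int]
    num_computed_tokens: int = 0

    def after_step(self, num_tokens_processed: int, sampled_token: int) -> None:
        # Tokens that went through the forward pass are now computed.
        self.num_computed_tokens += num_tokens_processed
        # The sampled token is only appended; it has not been fed through the
        # model yet, so it is not counted as computed until the next step.
        self.token_ids.append(sampled_token)


seq = ToySequence(token_ids=[1, 2, 3, 4, 5])
seq.after_step(num_tokens_processed=5, sampled_token=6)  # prefill
print(seq.token_ids, seq.num_computed_tokens)  # [1, 2, 3, 4, 5, 6] 5
seq.after_step(num_tokens_processed=1, sampled_token=7)  # next step feeds token 6
print(seq.num_computed_tokens)  # 6
```

After prefill, exactly one token (the freshly sampled one) remains uncomputed, which is what the next step consumes.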
I guess this is related to my previous comment about incrementing computed_tokens by 1. @varun-sundar-rabindranath, could you clarify?
@LiuXiaoxuanPKU @comaniac I have put up a PR, https://github.com/vllm-project/vllm/pull/8950, with a fix that reverts the updates. My bad, I totally misunderstood the semantics of num_computed_tokens. Sorry for the inconvenience! Thanks @LiuXiaoxuanPKU for catching this!