[core] Multi Step Scheduling
Adds initial multi-step scheduling support to vLLM. RFC: https://github.com/vllm-project/vllm/issues/6854
Current Status:
- 8/8: multi-node working
- 8/6: PP+TP working; PP+Ray fixed; ~~a few single-GPU perf regressions (easy fix)~~
- 8/2: PP works with MP; ready for an initial pass on the design
- 8/1: PP is very close to working. We do get the desired interleaving of steps between microbatches, which is great!
- 7/31: The current branch is in very rough shape after getting the RFC design working. Will clean up after adding TP/PP support, as some refactors may be needed. However, single GPU is ready for initial testing.
Cmd:
```
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --swap-space 16 --disable-log-requests --use-v2-block-manager --tensor-parallel-size 1 --worker-use-ray --pipeline-parallel-size 1 --gpu-memory-utilization 0.90 --max_forward_calls_per_step 8
```
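Once the server is up, a quick sanity check against the OpenAI-compatible endpoint can look like the minimal sketch below (not part of this PR; it assumes the default host/port of `localhost:8000` and uses only the standard completions API):

```python
# Minimal sanity check against the OpenAI-compatible server launched above.
# Assumes the default host/port (localhost:8000); adjust if you pass --host/--port.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Meta-Llama-3-8B",
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0.0,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```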
Benchmark

| Single GPU | Baseline (Req/s) | MS-4 (Req/s) | MS-8 (Req/s) | MS-12 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|---|---|
| A10G 8B Llama | 5.56 | - | 6.14 | OOM | OOM |
| H100 8B Llama | 23.45 | - | 43.64 | - | - |
| ~~H100 30B Llama~~ | 8.86 | - | 13.35 | - | 13.37 |

| PP=2 | Baseline (Req/s) | MS-4 (Req/s) | MS-8 (Req/s) | MS-12 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|---|---|
| A10G 8B Llama (microbatch=128) | 8.98 | - | 9.99 | - | - |
| ~~H100 8B Llama~~ | 23 | - | 31 | - | - |
| ~~H100 70B Llama~~ | 3.09 | 3.13 | 3.13 | - | - |

| TP=2 | Baseline (Req/s) | MS-4 (Req/s) | MS-8 (Req/s) | MS-12 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|---|---|
| A10G 8B Llama | 6.11 | - | 7.02 | - | - |

| TP=2, PP=2 | Baseline (Req/s) | MS-4 (Req/s) | MS-8 (Req/s) | MS-12 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|---|---|
| A10G 8B Llama (microbatch=128) | 5.99 | - | 7.15 | - | - |

TODO: Milestone 1: POC
- [X] Add `--max_forward_calls_per_step` to the CLI arguments, engine args, and `SchedulerConfig`
- [X] Changes to `SequenceGroupState` in `sequence.py` to track multi-step state
- [X] Add `MultiStepWorker` in `worker/` to cache multi-step state
- [X] Changes to `ModelRunner` to handle multi-step state (see the illustrative sketch after this list)
- [X] Reorganize input preparation in `ModelRunner` to reduce duplicate code
- [X] Async GPU->CPU transfer for sampled tokens
- [X] Async pythonization
- [X] Flash Attn backend
- [X] Cudagraph
- [X] Benchmarks (Ongoing)
- [x] TP
- [X] PP (works with MP and Ray, ~~mem leak somewhere with RAY~~)
- [x] PP+TP
- [X] multi-node
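For reviewers skimming the checklist above, here is a deliberately simplified, illustrative-only toy of the multi-step pattern: inputs are prepared once, N forward/sample calls run back to back while the cached inputs are advanced in place, and pythonization of the sampled tokens is deferred until after the last step. None of these names come from the PR's actual code.

```python
# Illustrative toy only; names here (StepState, fake_forward_and_sample, run_multi_step)
# are hypothetical and do NOT correspond to the classes/functions added by this PR.
from dataclasses import dataclass, field

@dataclass
class StepState:
    token_ids: list                                        # running token ids per sequence
    cached_samples: list = field(default_factory=list)     # per-step samples, pythonized late

def fake_forward_and_sample(token_ids):
    # Stand-in for model forward + sampling: returns one new token per sequence.
    return [len(seq) for seq in token_ids]

def run_multi_step(state: StepState, num_steps: int):
    for _ in range(num_steps):
        sampled = fake_forward_and_sample(state.token_ids)
        # Cache the raw result; the real code keeps this on-GPU and copies it asynchronously.
        state.cached_samples.append(sampled)
        # Advance the cached inputs in place so the next step needs no re-preparation.
        for seq, tok in zip(state.token_ids, sampled):
            seq.append(tok)
    # "Pythonize" everything only after the final step of the schedule.
    return state.cached_samples

state = StepState(token_ids=[[1, 2, 3], [4, 5]])
print(run_multi_step(state, num_steps=4))
```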
Milestone 2: Mergeable
- [X] Clean up data structures
- [ ] use `num_scheduler_steps` (see the hedged usage sketch after this list)
- [ ] Add tests
- [ ] Tests passing
- [X] Clean up `model_runner.py`, perhaps into a `multi_step_model_runner.py`?
- [x] Not a blocker, but https://github.com/vllm-project/vllm/pull/6971 will improve perf; the current version of it is included in this PR.
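For reference, once the flag is renamed to `num_scheduler_steps`, enabling multi-step from the offline entrypoint might look roughly like the sketch below. Treat it as an assumption: whether the `LLM` entrypoint accepts the keyword directly (rather than only via the CLI/engine args) is not confirmed by this PR.

```python
# Hedged sketch only: assumes num_scheduler_steps is exposed as an engine arg /
# LLM keyword after the Milestone 2 rename above; the exact surface may differ.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    use_v2_block_manager=True,   # matches the serving command at the top of the PR
    num_scheduler_steps=8,       # run 8 forward calls per scheduler invocation
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```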
Follow-up work:
- [ ] add logprob support in `_pythonize_sampler_output`
- [ ] support chunked prefill
- [ ] support guided decoding
- [ ] add flag to enforce synchronous pythonization (for logit processors and guided decoding)
- [ ] support spec-decode
- [ ] support prefix caching
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.
Once the PR is approved and ready to go, please make sure to run the full CI, as it is required to merge (or just use auto-merge).
To run the full CI, you can do one of the following:
- Comment `/ready` on the PR
- Add the `ready` label to the PR
- Enable auto-merge.
🚀
QQ: do you plan to split this PR into smaller pieces?
@rkooo567 If there are splits that make sense, I will definitely do that. Currently working on a small part here: https://github.com/vllm-project/vllm/pull/6971
@zhuohan123 @rkooo567 @Yard1 @comaniac @alexm-neuralmagic rebased and ready for review
Working on a smaller PR that contains parts of this.
/ready
Also before merge, can you please verify the throughput (tokens/sec) gain in the following settings to make sure the PR is good performance-wise:
- ShareGPT + Llama 8B + 1x H100/A100
- ShareGPT + Llama 70B + 8x H100/A100
Also, can you add which dataset you are using in your original benchmarks? Thanks!
@zhuohan123
I'm using ShareGPT for all the numbers, benchmarked with the `benchmark_serving.py` script.
See below for single GPU numbers.
```
[rank0]:   File "/data/woosuk/workspace/vllm/vllm/engine/output_processor/multi_step.py", line 88, in process_outputs
[rank0]:     assert valid_samples
```
@SolitaryThinker Huge thanks for the PR! QQ: I got the above error when running benchmark scripts with num_scheduler_steps > 1. Is this expected?
Hi @WoosukKwon. I see spec decode also has a class named `MultiStepWorker`; is it related to the `MultiStepWorker` in `vllm/worker/multi_step_worker.py` from this PR?
Hello :) Sorry to bother you, but what is the state of this? Is it planned for implementation in the new v1 engine? Thanks a lot!
My understanding is that this shouldn't be needed in v1, as this was a stopgap for the performance bottlenecks of the scheduler in v0. Someone more up-to-date can correct me though :)
Yep :) Interesting. I see the gap still exists, but as you are saying, it may be way smaller than before.
My results on version 0.9.2, using the v1 scheduler with NVIDIA Nsight, batch size 128, and Llama-3.1-8B:
- With only decode requests in the batch: approx. 1.2 ms of gap vs. 13 ms per step
- With both decode and (chunked) prefill requests in the batch: approx. 1.4 ms of gap vs. 52 ms per step

As you were saying, we may consider it negligible, or at least there are other optimizations far more important than this small margin.