[V1] DP scale-out (2/N): Decouple engine process management and comms
This decouples the management of engine processes from the IPC, and adds support for a mix of local and/or remote engines (where remote engines run on a different node).
When there are any remote engines, TCP transport is used for the ZMQ sockets; otherwise IPC (Unix domain socket based) transport is used.
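As a rough illustration (not the exact code in this PR), the transport choice might look something like the following, where `local_only`, `host`, and `port` are assumed inputs:

```python
import zmq

def bind_input_socket(ctx: zmq.Context, local_only: bool,
                      host: str, port: int) -> tuple[zmq.Socket, str]:
    """Bind the engine input queue socket, choosing the transport.

    Uses a Unix domain socket (ipc://) when all engines are local to this
    node, and TCP when any engine runs on another node.
    """
    if local_only:
        address = "ipc:///tmp/vllm_engine_input.sock"  # illustrative path
    else:
        address = f"tcp://{host}:{port}"
    socket = ctx.socket(zmq.PULL)
    socket.bind(address)
    return socket, address
```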
Engines are bootstrapped with the input queue address and use this to perform a handshake with the front-end running in the head node, which provides other necessary configuration details:
```
    Front-End                          Engine Core (N)
        |                                     |
(1)     | <------------- HELLO -------------- |
(2)     | -------- config / conn info ------> |  (address of output queue and
        |                                     |   data parallel torch process group)
        |                                     |
        |                                     |  [ engine init - load model ]
        |                                     |
(3)     | <------------- READY -------------- |
```
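As a rough sketch of this exchange from the engine side, assuming a plain ZMQ DEALER socket and JSON messages (the actual socket patterns and message types in the PR may differ):

```python
import zmq

def engine_handshake(input_queue_address: str, dp_rank: int) -> dict:
    """Hypothetical engine-side handshake: HELLO -> config -> READY."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(input_queue_address)

    # (1) HELLO: announce this engine's data parallel rank to the front-end.
    sock.send_json({"type": "HELLO", "dp_rank": dp_rank})

    # (2) The front-end replies with the remaining configuration, e.g. the
    #     output queue address and DP torch process group coordination info.
    config = sock.recv_json()

    # ... engine init: load the model, join the DP process group, etc. ...

    # (3) READY: tell the front-end this rank can accept work.
    sock.send_json({"type": "READY", "dp_rank": dp_rank})
    return config
```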
There is a new --headless option for vllm serve to run on secondary nodes, which launches one or more engines (data parallel ranks) without the front-end / API server.
Examples
This will run DP=4 with DP ranks 0 and 1 on the head node and ranks 2 and 3 on the second node:
```shell
# Node 1 (with ip address 10.99.48.128)
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \
    --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345

# Node 2
vllm serve $MODEL --headless --data-parallel-size-local 2 --data-parallel-start-rank 2 \
    --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
```
This will run DP=4 with only the API server on the first node and all engines on the second node:
```shell
# Node 1 (with ip address 10.99.48.128)
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 0 \
    --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345

# Node 2
vllm serve $MODEL --headless --data-parallel-size-local 4 \
    --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
```
It's assumed that local engine ranks (if any) are always lower than remote engine ranks. Note that it's not actually necessary to specify the (global) DP size on the secondary nodes, since this is obtained during the handshake. It would be straightforward to extend this to any other config that must be consistent across the ranks, so that it only needs to be specified in the head node CLI args.
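For example, on a secondary node the global engine ranks follow directly from --data-parallel-start-rank and --data-parallel-size-local (an illustrative helper, not the PR's actual code):

```python
def local_to_global_ranks(start_rank: int, local_size: int) -> list[int]:
    """Map local engine indices on this node to global DP ranks.

    On the head node start_rank is 0; on secondary nodes it comes from
    --data-parallel-start-rank. The global DP size itself need not be
    passed, since it is received during the handshake.
    """
    return [start_rank + i for i in range(local_size)]

# Node 2 in the first example above: start rank 2, two local engines.
assert local_to_global_ranks(2, 2) == [2, 3]
```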
TODO (this PR):
- [ ] Feedback
- [x] Some code simplification
- [x] Fix offline DP compatibility
- [ ] CI tests
Next PR:
- API server scale-out
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
My PR https://github.com/vllm-project/vllm/pull/15863 is based on Ray's support for multi-node DP. Could you provide some feedback on what modifications are needed so it can be merged with your PR?
is this "frontend" is one api server process?
is this "frontend" is one api server process?
So far, yes. I've been trying to open incremental PRs and am working on multi-API-server support right now as the next one (on top of this PR). It won't change much in terms of the deployment semantics though - just maybe one additional arg to specify how many API server procs (this applies only to the head node and so is mutually exclusive with --headless).
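For illustration only, the mutual exclusion could be expressed like this; the --api-server-count flag name here is a placeholder, not necessarily what the follow-up PR will use:

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser("vllm serve (sketch)")
    parser.add_argument("--headless", action="store_true",
                        help="Run engines only, no front-end / API server (secondary nodes).")
    # Hypothetical flag for the follow-up multi-API-server PR.
    parser.add_argument("--api-server-count", type=int, default=1,
                        help="Number of API server processes on the head node.")
    return parser

def validate(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    # A headless node runs no API server, so the two options conflict.
    if args.headless and args.api_server_count != 1:
        parser.error("--api-server-count cannot be combined with --headless")
```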
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
> TODO (this PR):
>
> - [ ] Feedback
> - [x] Some code simplification
> - [x] Fix offline DP compatibility
> - [ ] CI tests
>
> Next PR:
>
> - API server scale-out
Hey @njhill, just want to clarify - when you include "CI tests" in the TODO, does this mean you will be adding new unit tests? It might be valuable to unit-test the communication process in scaled-out scenarios, for example.
Yes, I plan to add tests to cover launching some engine(s) remotely, but wanted to get some feedback on the design first.
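For reference, such a test could launch the headless "node" and the head node as two processes on one multi-GPU host and poll the health endpoint, roughly like the sketch below (the model name, port, and timeouts are placeholders):

```python
import subprocess
import time

import requests

MODEL = "Qwen/Qwen2-0.5B-Instruct"  # placeholder small model
DP_ARGS = ["--data-parallel-address", "127.0.0.1", "--data-parallel-rpc-port", "13345"]

# "Remote" engine: DP rank 1, launched headless with no API server.
headless = subprocess.Popen(
    ["vllm", "serve", MODEL, "--headless", "--data-parallel-size-local", "1",
     "--data-parallel-start-rank", "1", *DP_ARGS])

# Head node: API server plus DP rank 0.
head = subprocess.Popen(
    ["vllm", "serve", MODEL, "--data-parallel-size", "2",
     "--data-parallel-size-local", "1", *DP_ARGS])

try:
    deadline = time.time() + 600
    while time.time() < deadline:
        try:
            if requests.get("http://127.0.0.1:8000/health", timeout=5).ok:
                print("server healthy with remote engine attached")
                break
        except requests.ConnectionError:
            pass
        time.sleep(5)
    else:
        raise RuntimeError("server did not become healthy in time")
finally:
    head.terminate()
    headless.terminate()
```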
Thanks @simon-mo, @youkaichao. Don't merge quite yet, I will push small updates and a multi-node CI test today.
Related to my last comment -- I have some changes in a private branch that allow turning off the pickle fallback in our custom msgpack encoder/decoder. I'm going to push a PR with just those changes. We can try using that here to see if anything breaks, and perhaps even have an environment variable that allows turning it back on if something unexpected breaks -- VLLM_ALLOW_INSECURE_PICKLE_FALLBACK or something ...
Posted https://github.com/vllm-project/vllm/pull/17427 with the first step -- allowing the pickle fallback to be turned off in the custom msgpack encoder/decoder, plus an env var to force it back on if something breaks unexpectedly.
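For context, the general shape of an opt-in pickle fallback in an msgspec-based msgpack encoder/decoder might look like this (illustrative only; the layout and extension code are not vLLM's actual implementation):

```python
import os
import pickle

import msgspec

# Off by default; the env var re-enables the insecure fallback if needed.
ALLOW_PICKLE = os.environ.get("VLLM_ALLOW_INSECURE_PICKLE_FALLBACK", "0") == "1"

def enc_hook(obj):
    """Fallback for types msgpack can't natively encode."""
    if not ALLOW_PICKLE:
        raise TypeError(
            f"Refusing to pickle {type(obj).__name__}; set "
            "VLLM_ALLOW_INSECURE_PICKLE_FALLBACK=1 to re-enable the fallback.")
    return msgspec.msgpack.Ext(1, pickle.dumps(obj))

def ext_hook(code: int, data: memoryview):
    if code == 1 and ALLOW_PICKLE:
        return pickle.loads(data)
    raise ValueError(f"Unsupported msgpack extension type {code}")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(ext_hook=ext_hook)
```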
We've talked about this some offline, but for the record, I would prefer that we don't introduce any additional multi-node interfaces that are not secured in any way and, worse, allow arbitrary code execution via pickle. I understand the argument that PyTorch says its distributed communications are insecure, but that's not a great reason to make our own code worse and harder to fix in the future.
The security concerns are hopefully addressed now that we have https://github.com/vllm-project/vllm/pull/17490 merged.
@russellb @youkaichao can you please help with a final round of review?
> We've talked about this some offline, but for the record, I would prefer that we don't introduce any additional multi-node interfaces that are not secured in any way and, worse, allow arbitrary code execution via pickle. I understand the argument that PyTorch says its distributed communications are insecure, but that's not a great reason to make our own code worse and harder to fix in the future.
>
> The security concerns are hopefully addressed now that we have #17490 merged.
Confirmed, yes -- thank you for working with me on this!
Fixing rebase issues
I've opened a separate PR for one of the new CI test failures which is unrelated: https://github.com/vllm-project/vllm/pull/18007
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
When will vllm.entrypoints.openai.api_server support --headless and --data-parallel-start-rank? @njhill
Hi, does this approach currently support IPv6 addresses?
I've been trying to fix IPv6 support anywhere I see it not working correctly. If you see it broken anywhere, let me know. I looked at it in this PR and think it should be OK now, though I didn't try it myself.
@russellb Hi, I think this PR can fix the problem: https://github.com/vllm-project/vllm/pull/18991. I have tried it.
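For anyone checking IPv6 here: the usual gotchas are that literal IPv6 addresses need brackets in tcp:// endpoints and that ZMQ sockets must have the IPV6 option enabled (sock.setsockopt(zmq.IPV6, 1)). A minimal sketch of the address formatting, not vLLM's actual helper:

```python
import ipaddress

def zmq_tcp_endpoint(host: str, port: int) -> str:
    """Build a tcp:// ZMQ endpoint, bracketing literal IPv6 addresses."""
    try:
        if ipaddress.ip_address(host).version == 6:
            host = f"[{host}]"
    except ValueError:
        pass  # a hostname, not a literal IP
    return f"tcp://{host}:{port}"

assert zmq_tcp_endpoint("10.99.48.128", 13345) == "tcp://10.99.48.128:13345"
assert zmq_tcp_endpoint("fd00::1", 13345) == "tcp://[fd00::1]:13345"
```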