[Doc]: Offline Inference Distributed Broken for TP
📚 The doc issue
https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_distributed.html This document suggests that vLLM can be run in a tensor-parallel setting with Ray for use in map_batches, but it is currently not possible to run a distributed inference engine with TP > 1.
Related Issue: https://github.com/vllm-project/vllm/issues/3190
Currently there is no way to get around the following error, even if you put GPUs on the head node:
ValueError: Ray does not allocate any GPUs on the driver node. Consider adjusting the Ray placement group or running the driver on a GPU node
The only way I can imagine this working right now is with a head-group-only Ray cluster, which is unusable for production use cases.
This may not be an issue in earlier versions of vLLM (these docs were added in 0.3.1; I am trying to confirm whether TP on Ray works in that version).
Suggest a potential alternative/fix
Unsure. Really looking for advice on how to do this from @zhuohan123 who has worked on the GPU Ray Executor and @c21 who submitted the example.
I think #2406 is related. Here is a permalink to the place where the error is raised.
Whatever is going wrong happens here:
```python
if worker_ip == driver_ip and self.driver_dummy_worker is None:
    # If the worker is on the same node as the driver, we use it
    # as the resource holder for the driver process.
    self.driver_dummy_worker = worker
    self.driver_worker = RayWorkerWrapper(
        worker_module_name="vllm.worker.worker",
        worker_class_name="Worker",
        trust_remote_code=self.model_config.trust_remote_code,
    )
else:
    # Else, added to the list of workers.
    self.workers.append(worker)

if self.driver_dummy_worker is None:
    raise ValueError(
        "Ray does not allocate any GPUs on the driver node. Consider "
        "adjusting the Ray placement group or running the driver on a "
        "GPU node.")
```
self.driver_dummy_worker will only be None if no bundle was scheduled on the same node as the driver process. This is happening to me even though I have a setup where it feels like this should be impossible: I'm using tensor_parallel_size=4, I have a head node with 2 GPUs and a worker node with 8 GPUs, and the driver script is being run from the worker with 8 GPUs.
This means that there is at least one bundle on each node, so at least one bundle should have the same IP as the driver, unless I am misunderstanding something.
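To make the failure condition concrete, here is a minimal pure-Python sketch of the matching rule in the snippet above, with Ray stripped out. The function name `pick_driver_worker` and the example IPs are hypothetical; the sketch only mirrors the rule that the first worker sharing the driver's IP becomes the driver's resource holder, and that the engine fails if no worker shares it.

```python
from typing import List, Optional, Tuple


def pick_driver_worker(driver_ip: str, worker_ips: List[str]) -> Tuple[int, List[int]]:
    """Hypothetical helper mirroring the vLLM loop (not part of vLLM).

    Returns (index of the driver's dummy worker, indices of remote workers).
    """
    driver_dummy: Optional[int] = None
    remote_workers: List[int] = []
    for i, worker_ip in enumerate(worker_ips):
        if worker_ip == driver_ip and driver_dummy is None:
            # First worker on the driver's node holds GPU resources for the driver.
            driver_dummy = i
        else:
            # Every other worker is driven remotely.
            remote_workers.append(i)
    if driver_dummy is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")
    return driver_dummy, remote_workers


# Driver co-located with its workers: succeeds.
print(pick_driver_worker("10.0.0.2", ["10.0.0.2"] * 4))

# Driver actor landed on a node that hosts none of its workers: fails.
try:
    pick_driver_worker("10.0.0.1", ["10.0.0.2"] * 4)
except ValueError as err:
    print("ValueError:", err)
```

Note that having bundles on both nodes is not enough on its own; what matters is whether any worker shares the node of the process that actually constructs the engine.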
Did you try updating the default placement group spec? https://github.com/vllm-project/vllm/blob/eefeb16464af5f3a61e3052d1a4128480bff7f47/vllm/executor/ray_utils.py#L110
AFAIK, it works normally if the driver process and a bundle are on the same node; if they are not, the serving may restart and retry. You can print out the driver and worker IPs to check whether this is the case, and you can also check which node each bundle of the placement group is on.
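A minimal sketch of the second check, assuming `ray.util.placement_group_table(pg)` returns a dict with a `bundles_to_node_id` mapping from bundle index to node ID (as in recent Ray releases) and that the driver's node ID comes from `ray.get_runtime_context().get_node_id()`. The helper name `bundles_on_node` is hypothetical, and the table below is made up so the logic runs without a cluster.

```python
from typing import Dict, List


def bundles_on_node(pg_table: Dict, node_id: str) -> List[int]:
    """Hypothetical helper: indices of placement-group bundles scheduled on node_id.

    Assumes pg_table["bundles_to_node_id"] maps bundle index -> node ID.
    """
    mapping = pg_table.get("bundles_to_node_id", {})
    return sorted(int(idx) for idx, nid in mapping.items() if nid == node_id)


# Made-up table: all four TP bundles landed on the 8-GPU worker node "node-b".
fake_table = {"bundles_to_node_id": {0: "node-b", 1: "node-b", 2: "node-b", 3: "node-b"}}
driver_node = "node-a"  # on a live cluster: ray.get_runtime_context().get_node_id()

if not bundles_on_node(fake_table, driver_node):
    print("No bundle shares the driver's node, so the ValueError above is expected.")
```

On a live cluster you would pass `placement_group_table(pg)` and the real driver node ID in place of the fake values.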
@codezakh I think this is definitely where the error originates. It might have something to do with making the vLLM predictor class a remote actor (in an attempt to push it onto a worker node): does the vLLM engine then try to re-wrap it in another remote call? I'm not totally sure what the bad interaction is...
@nkwangleiGIT I have certainly tried all manner of using a PG to force the actual model inference onto the worker nodes. Please also look at the linked issue, whose author also tried a bunch of things: #3190. I would encourage you to try a 3-node cluster (1 head and 2 workers in the worker group) and run TP inference with TP size=2 inside map_batches. For the life of me I cannot make it work. If you can, please post the code, as you'd be a lifesaver. I'm quite sure it is broken as-is.
@nkwangleiGIT thanks for the tip! I was able to use it for a solution. @sam-h-bean I think I figured out what's going wrong and how to fix it.
In my case, the vLLM engine is inside a Ray Actor that is assigned no logical GPUs. This actor (the driver) gets scheduled onto some arbitrary node, while the vLLM workers that are children of the driver actor get scheduled onto a node that has GPUs, as they require one logical GPU per worker. In this case, no worker shares the same IP address as the driver and vLLM fails.
To fix it, you need to ensure the driver gets scheduled onto the same node as at least one of the workers.
The solution that worked for me was to do this:
- Manually calculate a placement group based on the tensor parallelism and pipeline parallelism.
- Create the placement group with strategy="STRICT_PACK" to ensure all bundles are placed on the same physical node, and use a PlacementGroupSchedulingStrategy with placement_group_capture_child_tasks=True to ensure that all child tasks use the placement group I created.
- Instantiate the vLLM engine inside an actor that is assigned the placement group, so that the vLLM engine and all its workers are forced to use this placement group and get scheduled onto the same physical node.
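The first step, computing the bundles from the parallelism degrees, can be sketched as a small function that reproduces the bundle list used in the example code: one GPU per vLLM worker, with one CPU on the first bundle for the actor that hosts the engine. The name `make_vllm_bundles` is hypothetical, and reserving exactly one CPU on the first bundle is an assumption taken from that example, not a documented vLLM requirement.

```python
from typing import Dict, List


def make_vllm_bundles(tensor_parallel_size: int,
                      pipeline_parallel_size: int = 1) -> List[Dict[str, int]]:
    """Hypothetical helper: one GPU bundle per worker, plus 1 CPU on the first bundle."""
    if tensor_parallel_size < 1 or pipeline_parallel_size < 1:
        raise ValueError("parallel sizes must be positive")
    world_size = tensor_parallel_size * pipeline_parallel_size
    # The first bundle also carries the CPU requested by the engine-hosting actor.
    bundles = [{"GPU": 1, "CPU": 1}]
    bundles += [{"GPU": 1} for _ in range(world_size - 1)]
    return bundles


print(make_vllm_bundles(4))
```

With Ray available, the result would be passed as `placement_group(make_vllm_bundles(4), strategy="STRICT_PACK")`.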
I had to use this workaround because create_engine_config does not use the placement_group variable: it is not defined on EngineArgs even though it appears in ParallelConfig.
I've posted some example code below (my setup is a little different from yours). What you want to ensure is that the call to LLM() happens inside a Ray actor that is assigned a placement group as described above. You can do this, as I did, by wrapping the LLM class with the ray.remote decorator.
```python
from ray.util.placement_group import (
    placement_group,
    placement_group_table,
    remove_placement_group,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
import ray
from vllm import LLM

ray.init(address="auto")

# Create a placement group with 4 GPUs and at least 1 CPU.
# (assuming a tensor parallelism of 4 and pipeline parallelism of 1)
pg = placement_group(
    [
        {"GPU": 1, "CPU": 1},
        {"GPU": 1},
        {"GPU": 1},
        {"GPU": 1},
    ],
    # Ensure all bundles get placed onto the same physical node.
    strategy="STRICT_PACK",
)

# Block until the placement group is ready.
ray.get(pg.ready(), timeout=10)

# Placement group is ready.
llm_constructor = ray.remote(LLM).options(
    num_cpus=1,  # Required so that the placement group is not ignored.
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        # IMPORTANT: This ensures that all child tasks spawned by
        # this actor (i.e., all of vLLM's RayGPUExecutor workers)
        # are scheduled onto resources of this placement group.
        placement_group_capture_child_tasks=True,
        placement_group=pg,
    ),
)
llm = llm_constructor.remote(
    model=...,
    tensor_parallel_size=4,
    worker_use_ray=True,
)
```
Can you also confirm whether it ever worked (for example, in 0.3.1)? Maybe this needs new feature support.
@rkooo567 I wasn't able to get this to work on earlier versions, so I don't think it has ever worked (but I could be wrong).
@codezakh were you able to get a full working version using map_batches to distribute work across N Ray tasks with TP=4 each? I was able to get something working with a single actor, but not with multiple actors yet.
I can take a look at this in a few days as well. It seems like not supporting TP > 1 badly limits Ray batch inference (because some models cannot run with TP = 1).
I also faced this issue. Has anyone found a solution yet?
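Because STRICT_PACK places every bundle of a group on a single node, the number of independent TP replicas (for example, one vLLM actor per map_batches worker) is bounded node by node, not by the cluster's total GPU count. A back-of-the-envelope sketch, with a hypothetical helper name and ignoring CPU, memory, and other co-scheduled work:

```python
from typing import List


def max_strict_pack_replicas(gpus_per_node: List[int], tensor_parallel_size: int) -> int:
    """Hypothetical helper: count TP groups that fit when each group must sit on one node."""
    if tensor_parallel_size < 1:
        raise ValueError("tensor_parallel_size must be positive")
    # Leftover GPUs on a node cannot be combined with GPUs on another node.
    return sum(gpus // tensor_parallel_size for gpus in gpus_per_node)


# The setup described earlier in the thread: a 2-GPU head node and an 8-GPU worker node.
print(max_strict_pack_replicas([2, 8], tensor_parallel_size=4))
```

For that setup, only the 8-GPU node can host a TP=4 group, so at most two such actors fit, and the head node's GPUs stay unused.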
Hey folks, Ray Data team is adding an API to allow for this to be fixed: https://github.com/ray-project/ray/pull/45143
Once the PR is merged to Ray and released, I will update the example.
@Yard1 any updates on this?
@stikkireddy Hoping to make a PR tomorrow!
I'm not sure whether this would be of any help, but you can now also use TP without Ray workers for the LLM itself by passing distributed_executor_backend="mp" when creating the LLM.
Opened https://github.com/vllm-project/vllm/pull/4871 to fix.