vllm
Remove Ray as a dependency
Using Ray here is considered overkill. You can create a multi-process distributed environment easily using torch.distributed or an MPI launcher. Internally you can leverage the NCCL or MPI communication protocols for inter-process communication.
I would disagree. Ray has extended serving support, and extends toward LLMs and streaming. There are also tools like https://ray-project.github.io/kuberay/components/operator/
Perhaps isolate Ray or make it optional.
You can run vLLM without Ray on a single GPU.
For distributed settings, vLLM has a centralized scheduler and memory manager, whose control messages need to be broadcast to all model workers. We find Ray the easiest library for doing so. With MPI, we would need to implement this centralized logic in an SPMD program, which is not a very natural choice.
We are open to an elegant solution that does not use Ray. If you are interested, please feel free to share your ideas here and contribute to the repo!
@michaelfeil I believe what you said has nothing to do with how we do tensor parallel communication. Ray is a good package for starting a cross-node (cross-machine) cluster with a shared memory space, but it is not ideal for a standalone machine with multiple GPUs. Most of the current standard tensor parallel libraries use either MPI, torch.distributed, or simply launch multiple processes to start inference. For example:
- DeepSpeed launcher: just spins up multiple processes with a socket listener; internally it uses NCCL for communication
- FasterTransformer: uses MPI as the external launcher, with NCCL inside for communication
- PiPPy: uses torch.distributed as the launcher (pure NCCL environment)
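The "just spin up multiple processes" pattern above can be sketched with Python's stdlib `multiprocessing` (a minimal, hypothetical launcher; a real one would initialize NCCL/torch.distributed inside each worker instead of just reporting readiness):

```python
import multiprocessing as mp

def worker(rank, world_size, queue):
    # A real launcher would set MASTER_ADDR/MASTER_PORT here and call
    # torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size).
    queue.put((rank, f"rank {rank}/{world_size} initialized"))

def launch(world_size):
    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(r, world_size, queue))
             for r in range(world_size)]
    for p in procs:
        p.start()
    # Collect one readiness message per rank.
    results = dict(queue.get() for _ in range(world_size))
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    print(launch(2))
```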
Having a strong tie to a distributed framework is generally not a good idea: it introduces too much overhead and unused dependencies.
@lanking520 I missed that point and misunderstood you a bit. Thanks for the additional explanation.
@lanking520 Thanks for your comment. We indeed use NCCL for cross-GPU tensor communication. However, in vLLM, we also need to pass some metadata ("control messages") from the scheduler to the workers. The metadata are basically Python objects like lists and dictionaries stored in CPU memory. We found it convenient to use Ray for sending this metadata to the workers. To my understanding, it's not easy to use MPI or torch.dist to send these Python objects to workers, is it?
@WoosukKwon MPI should make that easy enough if you want; the mpi4py Python package allows you to send serialized objects. torch.dist can only pass tensors, which makes it a bit harder to do.
https://mpi4py.readthedocs.io/en/stable/overview.html#communicating-python-objects-and-array-data
For torch.dist, if you serialize the object to bytes and pass it through as a tensor, this will also work. But the CPU-to-GPU and GPU-to-CPU copies are kind of expensive.
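The serialize-to-bytes step looks like this (stdlib only; the part where the bytes are copied into a uint8 tensor and broadcast over torch.dist is elided, and the field names are hypothetical):

```python
import pickle

# Hypothetical control message: plain CPU-side Python objects.
meta = {"seq_ids": [3, 1], "prompt_lens": [16, 8]}

# Serialize to bytes. With torch.dist, these bytes would be copied into
# a uint8 tensor, broadcast to all ranks, then decoded on each worker.
payload = pickle.dumps(meta)
restored = pickle.loads(payload)
assert restored == meta
```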
Given that in your case you want the messages to stay on the CPU, just using the MPI interface for access should work.
You can use MPI to spin up the distributed environment while enabling torch.dist("nccl") for communication, so you have GPU-to-GPU and CPU-to-CPU communication at the same time.
If it is purely cross-CPU-process communication, you can also use shared-memory access control:
https://docs.python.org/3/library/multiprocessing.shared_memory.html
You can launch the instances by purely spinning up multiple processes, or with MPI, or torch.dist; all should work. Even launching with Ray will work in this case... but it is a bit harder to manage the memory, since there is potential for illegal read/write access.
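A sketch of that shared-memory route using the stdlib module linked above (a scheduler creates a named segment and writes a pickled message; a worker attaches by name and reads it; the illegal read/write risk mentioned is exactly why the close()/unlink() lifetime management matters):

```python
import pickle
from multiprocessing import shared_memory

# Hypothetical control message from the scheduler.
meta = {"seq_ids": [7], "block_tables": {7: [2, 4]}}
payload = pickle.dumps(meta)

# Scheduler side: create a named segment and write the pickled bytes.
shm = shared_memory.SharedMemory(create=True, size=len(payload))
shm.buf[:len(payload)] = payload

# Worker side: attach to the same segment by name and decode.
worker_view = shared_memory.SharedMemory(name=shm.name)
restored = pickle.loads(bytes(worker_view.buf[:len(payload)]))

# Careful lifetime management avoids dangling reads/writes.
worker_view.close()
shm.close()
shm.unlink()
```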
PyTorch distributed supports broadcasting pickled Python objects like lists, etc. Ray isn't required for it.
Thanks for the advice @lanking520! We will take that into account. Currently, we are focused on fixing bugs and adding requested models. After those are addressed, we will look into the alternatives you suggested and see whether they improve the UX in vLLM.
I guess you can do this with torch.distributed.broadcast_object_list or torch.distributed.gather_object?
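A minimal single-process sketch of the broadcast_object_list call (gloo backend with world_size=1 just to show the API shape; in the multi-worker setting the driver would be src=0 and each worker a separate rank passing `[None]` placeholders that get filled in place):

```python
import tempfile
import torch.distributed as dist

# Single-member process group, only to demonstrate the API shape.
init_file = tempfile.NamedTemporaryFile(delete=False)
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file.name}",
    rank=0,
    world_size=1,
)

# The src rank fills the list with plain Python objects; non-src ranks
# would pass [None] and receive the unpickled objects in place.
control_msgs = [{"seq_ids": [1, 2], "prompt_lens": [16, 8]}]
dist.broadcast_object_list(control_msgs, src=0)

dist.destroy_process_group()
print(control_msgs[0])
```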
We've now opened a PR to add support for this #3466