[Bug]: Some questions regarding the usage of NCCL allreduce/broadcast/allgather/send/recv in vLLM via pynccl and torch.distributed
Your current environment
🐛 Describe the bug
Hello vLLM experts,
While reading the vLLM code, I noticed that `pynccl.py` only wraps allreduce and send/recv, and does not wrap the allgather and broadcast operations. When the `GroupCoordinator` class calls allreduce or send/recv, it tries to go through the pynccl comm's allreduce/send/recv (if pynccl's `disabled` flag is false; does this only take effect during CUDA graph capture?). However, when calling allgather and broadcast, it goes directly to `torch.distributed.broadcast` and `torch.distributed.all_gather`.
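For reference, here is a simplified, hypothetical sketch of the dispatch pattern described above. The class name and attributes such as `pynccl_comm.disabled` are my own shorthand, not the exact vLLM API; the point is only that allreduce prefers the pynccl comm when it is enabled, while allgather falls back to `torch.distributed`.

```python
# Hypothetical, simplified sketch of the dispatch described above -- not the
# actual vLLM GroupCoordinator implementation.
import torch
import torch.distributed as dist


class SimpleGroupCoordinator:
    def __init__(self, pynccl_comm, device_group):
        self.pynccl_comm = pynccl_comm      # ctypes-based NCCL wrapper (pynccl)
        self.device_group = device_group    # torch.distributed process group

    def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        # allreduce (and send/recv) go through the pynccl communicator when it
        # is enabled, e.g. during CUDA graph capture.
        if self.pynccl_comm is not None and not self.pynccl_comm.disabled:
            return self.pynccl_comm.all_reduce(tensor)
        dist.all_reduce(tensor, group=self.device_group)
        return tensor

    def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # allgather/broadcast go straight to torch.distributed in the code
        # version the question refers to.
        world_size = dist.get_world_size(group=self.device_group)
        out = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(out, tensor, group=self.device_group)
        return torch.cat(out, dim=dim)
```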
From reading this code, I have the following three questions:

1. Why does `pynccl` not wrap `allgather` and `broadcast`? Is there a specific consideration behind this? Are the `nccl_comm` in `pynccl` and the `nccl_comm` in `torch.distributed` the same?
2. When is `pynccl` called? From the code, it seems that the `pynccl` comm can only be called when using CUDA graphs. Please confirm this. Additionally, if we use PyTorch's own operations entirely during CUDA graph capture, is that not feasible? Why do we need to wrap an additional `pynccl` comm?
3. If a model has `allreduce`, `broadcast`, `allgather`, and `send`/`recv`, then during CUDA graph capture it effectively uses NCCL operations from two different NCCL comms. Wouldn't this pose a risk?
@youkaichao I think you might have the most context on this?
We use pynccl for two reasons:

1. `torch.distributed` does too many complicated things, like lazy initialization and internal bucketing, which makes it very difficult to be compatible with CUDA graph. For example, your warmup must run at least 11 DDP-enabled eager iterations before CUDA graph capture.
2. PyTorch 2.2 ships with NCCL 2.19, in which we found a bug, so we had to use NCCL 2.18 instead. Fortunately, that problem was eventually solved, but we still keep the pynccl wrapper in case this kind of problem occurs again in the future.
These two points should answer your second question.
For the remaining questions:
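To illustrate the CUDA graph point: a thin wrapper that simply launches the NCCL kernel on the current stream can be captured after a short warmup. Below is a minimal sketch, assuming a pynccl-like communicator object `comm` whose `all_reduce` does exactly that (the helper name `capture_allreduce_graph` is mine, not part of vLLM).

```python
# Minimal sketch (assuming a pynccl-like communicator `comm` that issues
# ncclAllReduce directly on the current CUDA stream): because the call is a
# plain stream operation with no lazy init or bucketing, it can be captured
# into a CUDA graph after a short warmup.
import torch


def capture_allreduce_graph(comm, hidden: torch.Tensor) -> torch.cuda.CUDAGraph:
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        # Warmup: run the collective once outside the graph so that any
        # one-time initialization happens eagerly.
        comm.all_reduce(hidden)
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # The NCCL kernel launched here is recorded into the graph and
        # re-executed later via graph.replay().
        comm.all_reduce(hidden)
    return graph
```

With `torch.distributed`, the same capture would also have to account for the lazy initialization and bucketing mentioned above, hence the 11-iteration DDP warmup requirement.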
> Why does pynccl not wrap `allgather` and `broadcast`?
Because the main purpose of pynccl is CUDA graph capture, and these two operations are not part of the CUDA graph, we didn't add them to pynccl. In theory, they can be added (see the sketch below). In fact, the first version of pynccl only wrapped allreduce; it was only when we added pipeline parallel support that we added wrappers for send/recv.
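As a thought experiment, an allgather wrapper could look roughly like the existing allreduce one. The sketch below is hypothetical, not the actual vLLM API: it assumes a pynccl-style communicator exposing `disabled`, `world_size`, a comm handle `comm`, and a ctypes library object `nccl` with an `ncclAllGather` binding, plus the helper types exported by `pynccl_wrapper.py`.

```python
# Hypothetical sketch, not the actual vLLM API: an all_gather built on a
# pynccl-style communicator. Assumes pynccl_wrapper.py exports buffer_type,
# cudaStream_t and ncclDataTypeEnum, and that the library object has an
# ncclAllGather binding.
import torch
from vllm.distributed.device_communicators.pynccl_wrapper import (
    buffer_type, cudaStream_t, ncclDataTypeEnum)


def pynccl_all_gather(comm, input_: torch.Tensor) -> torch.Tensor:
    """Gather `input_` from every rank into a new tensor along a new dim 0."""
    if comm.disabled:
        return input_
    stream = torch.cuda.current_stream()
    output = torch.empty((comm.world_size, *input_.shape),
                         dtype=input_.dtype, device=input_.device)
    # NCCL signature: ncclAllGather(sendbuff, recvbuff, sendcount, datatype,
    #                               comm, stream)
    comm.nccl.ncclAllGather(
        buffer_type(input_.data_ptr()), buffer_type(output.data_ptr()),
        input_.numel(), ncclDataTypeEnum.from_torch(input_.dtype),
        comm.comm, cudaStream_t(stream.cuda_stream))
    return output
```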
> If a model has `allreduce`, `broadcast`, `allgather`, and `send`/`recv`, then during CUDA graph capture it effectively uses NCCL operations from two different NCCL comms. Wouldn't this pose a risk?
That's possible. We haven't tried it yet. Currently only the decode phase of the model forward pass uses CUDA graph, so only allreduce will occur inside the graph.
Thanks for the detailed explanation, I got all the answers. I will close this issue.
Issue:
The customer is using vLLM 0.4.3 and trying to call torch.distributed.allgather during CUDA graph capture, which causes NCCL to hang at runtime.
Solution:
The CUDA graph issue stems from insufficient warmup iterations in the code. It is recommended to wrap the NCCL operations by referring to the vLLM 0.5.4 code at https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py, and then call them from within vLLM (see the sketch below).
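For completeness, here is a minimal, hedged sketch of what wrapping an extra NCCL collective through ctypes looks like. The real `pynccl_wrapper.py` is more elaborate (library discovery, error handling, enum conversion); loading `libnccl.so.2` directly here is just an assumption for illustration.

```python
# Minimal sketch of a ctypes binding for one NCCL collective, in the spirit of
# vLLM's pynccl_wrapper.py. Assumes libnccl.so.2 is on the library path.
import ctypes

ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

nccl = ctypes.CDLL("libnccl.so.2")

# ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff,
#                            size_t sendcount, ncclDataType_t datatype,
#                            ncclComm_t comm, cudaStream_t stream)
nccl.ncclAllGather.restype = ncclResult_t
nccl.ncclAllGather.argtypes = [
    buffer_type, buffer_type, ctypes.c_size_t,
    ctypes.c_int,            # ncclDataType_t
    ncclComm_t, cudaStream_t,
]


def all_gather_into(nccl_comm: ncclComm_t, sendbuff: int, recvbuff: int,
                    count: int, dtype_enum: int, stream_ptr: int) -> None:
    # Raw pointers come from tensor.data_ptr() / torch_stream.cuda_stream.
    result = nccl.ncclAllGather(buffer_type(sendbuff), buffer_type(recvbuff),
                                count, dtype_enum, nccl_comm,
                                cudaStream_t(stream_ptr))
    assert result == 0, f"ncclAllGather failed with error code {result}"
```

Calling such a binding directly on the capture stream avoids the extra machinery in torch.distributed, which is exactly what the pynccl path in vLLM does for allreduce and send/recv.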