[Bug] Watchdog caught collective operation timeout
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- [x] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
- [x] 5. Please use English, otherwise it will be closed.
Describe the bug
I run DeepSeek-R1 on 4 A100 nodes, with 8×A100-80G GPUs in each node. On one node, I see the following logs:
[2025-02-22 08:19:45 TP26] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP29] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP27] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP25] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP31] Load weight begin. avail mem=78.41 GB
[2025-02-22 08:19:45 TP30] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP24] Load weight begin. avail mem=78.41 GB
[2025-02-22 08:19:45 TP28] Load weight begin. avail mem=78.38 GB
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
[2025-02-22 08:21:43 TP26] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:21:43 TP29] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:21:43 TP31] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.23 GB
[2025-02-22 08:21:43 TP24] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.23 GB
[2025-02-22 08:21:43 TP27] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:21:43 TP28] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:21:44 TP30] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:21:50 TP25] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[rank29]:[E222 08:31:44.059609287 ProcessGroupNCCL.cpp:616] [Rank 29] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600023 milliseconds before timing out.
[rank29]:[E222 08:31:44.060357685 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 29] Exception (either an error or timeout) detected by watchdog at work: 2, last enqueued NCCL work: 2, last completed NCCL work: 1.
[rank29]:[E222 08:31:44.060436972 ProcessGroupNCCL.cpp:1834] [PG ID 0 PG GUID 0(default_pg) Rank 29] Timeout at NCCL work: 2, last enqueued NCCL work: 2, last completed NCCL work: 1.
[rank29]:[E222 08:31:44.060491163 ProcessGroupNCCL.cpp:630] [Rank 29] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank29]:[E222 08:31:44.060551032 ProcessGroupNCCL.cpp:636] [Rank 29] To avoid data inconsistency, we are taking the entire process down.
[rank24]:[E222 08:31:44.061790357 ProcessGroupNCCL.cpp:616] [Rank 24] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600015 milliseconds before timing out.
[rank24]:[E222 08:31:44.062657112 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 24] Exception (either an error or timeout) detected by watchdog at work: 2, last enqueued NCCL work: 2, last completed NCCL work: 1.
and I know that it is caused by another node:
[2025-02-22 08:19:45 TP11] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP8] Load weight begin. avail mem=78.41 GB
[2025-02-22 08:19:45 TP14] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP9] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP12] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP15] Load weight begin. avail mem=78.41 GB
[2025-02-22 08:19:45 TP13] Load weight begin. avail mem=78.38 GB
[2025-02-22 08:19:45 TP10] Load weight begin. avail mem=78.38 GB
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
Cache shape torch.Size([163840, 64])
[2025-02-22 08:36:13 TP15] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.23 GB
[2025-02-22 08:36:13 TP12] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:36:13 TP13] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:36:13 TP11] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:36:13 TP9] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:36:13 TP14] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.21 GB
[2025-02-22 08:36:13 TP8] Load weight end. type=DeepseekV3ForCausalLM, dtype=torch.bfloat16, avail mem=37.23 GB
ly-node-214-23-54-4:99:915 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer ly-node-214-23-54-5<54666>
ly-node-214-23-54-4:99:915 [0] NCCL INFO misc/socket.cc:752 -> 6
ly-node-214-23-54-4:99:915 [0] NCCL INFO transport/net_socket.cc:474 -> 6
ly-node-214-23-54-4:99:915 [0] NCCL INFO transport/net.cc:1302 -> 6
ly-node-214-23-54-4:99:915 [0] NCCL INFO proxy.cc:698 -> 6
ly-node-214-23-54-4:99:915 [0] NCCL INFO proxy.cc:878 -> 6 [Progress Thread]
It takes about 17 minutes to load the weights on this node, which exceeds the watchdog timeout of 10 minutes.
However, neither setting TORCH_NCCL_ENABLE_MONITORING=0 nor setting TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=10000000 helps.
I then read the SGLang source code, and I think this timeout is determined by _DEFAULT_PG_NCCL_TIMEOUT, which comes from PyTorch.
Finally, I think I found the definition of the watchdog timeout. It is defined in /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp, line 132:
constexpr auto kProcessGroupNCCLDefaultTimeout =
std::chrono::milliseconds(10 * 60 * 1000);
However, changing this constexpr directly has no effect, since the value is already compiled into libtorch. I wonder how I can change the watchdog timeout, or whether there is some way to simply shut down the watchdog timer.
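As far as I can tell, this compiled default is only used when no explicit timeout is passed to torch.distributed.init_process_group, so the override has to happen on the Python side. A minimal sketch with illustrative values (not SGLang's actual call site):

```python
from datetime import timedelta

import torch.distributed as dist

# Illustrative sketch only: passing an explicit timeout overrides the compiled-in
# 10-minute NCCL default for this process group. Values here are placeholders.
dist.init_process_group(
    backend="nccl",
    init_method="env://",  # assumes MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE are set
    timeout=timedelta(hours=1),
)
```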
I think this is an important bug, and I would really appreciate a fix.
Reproduction
The command is:
TORCH_NCCL_ENABLE_MONITORING=0 TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=10000000 python3 -m sglang.launch_server --model-path /DeepSeek-R1-BF16 --watchdog-timeout 0 --tensor-parallel-size 32 --context-length 32768 --dist-init-addr 192.168.0.2:6379 --nnodes 4 --node-rank 2 --trust-remote-code
and the environment variables are:
-e GLOO_SOCKET_IFNAME='bond0' \
-e TP_SOCKET_IFNAME='bond0' \
-e NCCL_SOCKET_IFNAME='bond0' \
-e NCCL_DEBUG='info' \
-e NCCL_NET='Socket' \
-e NCCL_IB_DISABLE='0' \
-e TORCH_NCCL_ENABLE_MONITORING='0' \
-e TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC='10000' \
-e TORCH_NCCL_BLOCKING_WAIT='0' \
Environment
The environment info:
INFO 02-22 08:52:08 __init__.py:190] Automatically detected platform cuda.
Python: 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA A100-SXM4-80GB
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.90.07
PyTorch: 2.5.1+cu124
sgl_kernel: 0.0.3.post6
flashinfer: 0.2.1.post2+cu124torch2.5
triton: 3.1.0
transformers: 4.48.3
torchao: 0.8.0
numpy: 1.26.4
aiohttp: 3.11.12
fastapi: 0.115.8
hf_transfer: 0.1.9
huggingface_hub: 0.28.1
interegular: 0.3.3
modelscope: 1.23.0
orjson: 3.10.15
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.2.1
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.63.2
tiktoken: 0.9.0
anthropic: 0.45.2
decord: 0.6.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 PXB NODE NODE SYS SYS 0-31,64-95 0 N/A
GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12 PXB NODE NODE SYS SYS 0-31,64-95 0 N/A
GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12 NODE PXB NODE SYS SYS 0-31,64-95 0 N/A
GPU3 NV12 NV12 NV12 X NV12 NV12 NV12 NV12 NODE PXB NODE SYS SYS 0-31,64-95 0 N/A
GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12 SYS SYS SYS PXB NODE 32-63,96-127 1 N/A
GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 SYS SYS SYS PXB NODE 32-63,96-127 1 N/A
GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 SYS SYS SYS NODE PXB 32-63,96-127 1 N/A
GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X SYS SYS SYS NODE PXB 32-63,96-127 1 N/A
NIC0 PXB PXB NODE NODE SYS SYS SYS SYS X NODE NODE SYS SYS
NIC1 NODE NODE PXB PXB SYS SYS SYS SYS NODE X NODE SYS SYS
NIC2 NODE NODE NODE NODE SYS SYS SYS SYS NODE NODE X SYS SYS
NIC3 SYS SYS SYS SYS PXB PXB NODE NODE SYS SYS SYS X NODE
NIC4 SYS SYS SYS SYS NODE NODE PXB PXB SYS SYS SYS NODE X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
ulimit soft: 1048576
cc @zhyncs
I find that many users hit the same bug, and I wonder if there is any way to fix this. cc @burling @Orangels @CallmeZhangChenchen @aooxin in issue #3368, cc @sitabulaixizawaluduo in issue #3360.
Does it work to set --watchdog-timeout 3600? This allows for a 1-hour timeout.
No, it doesn't work. The default value of --watchdog-timeout is 300 (in sglang/srt/server_args.py, line 81), i.e., 5 minutes, which differs from the default NCCL watchdog timeout of 10 minutes. I also tried setting this value to 10000000 and to 0, but neither works.
Thank you for your help.
I suspect that this is not due to slow weight loading, there might be some communication issues cross-node. I will look into it.
I appreciate your help. I hope you can fix this problem ❤ so I can enjoy the features of SGLang.
@FrankLeeeee @teadross, we encountered this issue with 4 A800 nodes. It is the slow weight loading that causes the torch c10d watchdog timeout.
You need to update the SGLang code to set a longer process-group timeout. Loading the 1.3T BF16 weights takes me about 20 minutes (the default timeout is 600 seconds); with a longer timeout, the weights load successfully. Reference the code change below.
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/distributed/parallel_state.py#L978
from datetime import timedelta

# this backend is used for WORLD
torch.distributed.init_process_group(
    backend=backend,
    init_method=distributed_init_method,
    world_size=world_size,
    rank=rank,
    timeout=timedelta(seconds=3600),
)
From west-hpc.com team.
OK, noted. I guess we can make it a server arg.
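For illustration, a hedged sketch of how such a server arg might be wired up; the flag name --dist-timeout and the wiring below are assumptions, not SGLang's actual implementation:

```python
from datetime import timedelta
import argparse

import torch.distributed as dist

# Hypothetical sketch only: expose the process-group timeout as a CLI flag
# and forward it to init_process_group. Not SGLang's actual code.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dist-timeout",
    type=int,
    default=600,
    help="Timeout (seconds) for torch.distributed collectives, e.g. during weight loading.",
)
args, _ = parser.parse_known_args()

dist.init_process_group(
    backend="nccl",
    init_method="env://",  # assumes the usual MASTER_ADDR/MASTER_PORT env setup
    timeout=timedelta(seconds=args.dist_timeout),
)
```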
I will try it later. Thanks!
I have created a PR: #3803
@teadross, can you pull the latest main branch and try again? This bug seems to be solved according to #3836.
This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.