MPI_Init_thread hangs with MPI_THREAD_MULTIPLE and explicitly allocated VCIs when used with PyTorch
Hi!
I have been testing the PyTorch ProcessGroup MPI backend with the MPIX Stream extension, with the goal of enqueuing collective operations onto a separate device stream to enable fine-grained overlap (just like NCCL). However, I am hitting a hang during MPI_Init_thread in the UCX backend when an explicit VCI is reserved for the device stream, well before execution reaches the collective enqueue call.
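For context, the intended per-rank setup (once initialization succeeds) is roughly the sketch below; identifiers are placeholders, error checking is omitted, and the hang happens inside MPI_Init_thread itself, i.e. before any of this runs:

#include <mpi.h>
#include <cuda_runtime.h>

/* Simplified sketch of the intended usage; names are placeholders. */
void setup_stream_comm(MPI_Comm parent, MPIX_Stream *mpi_stream, MPI_Comm *stream_comm)
{
    cudaStream_t cuda_stream;
    cudaStreamCreate(&cuda_stream);

    /* Attach the CUDA stream to an MPIX stream via info hints. */
    MPI_Info info;
    MPI_Info_create(&info);
    MPI_Info_set(info, "type", "cudaStream_t");
    MPIX_Info_set_hex(info, "value", &cuda_stream, sizeof(cuda_stream));

    /* The stream is expected to take one of the VCIs reserved through
     * MPIR_CVAR_CH4_RESERVE_VCIS. */
    MPIX_Stream_create(info, mpi_stream);
    MPI_Info_free(&info);

    /* Collectives would then be enqueued on this communicator
     * (MPIX *_enqueue variants). */
    MPIX_Stream_comm_create(parent, *mpi_stream, stream_comm);
}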
The testing is done using a PyTorch-level collective benchmark (here) with the following launch commands:
export MASTER_PORT=$(get_free_port)
export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_GPUS_PER_NODE))
echo "WORLD_SIZE="$WORLD_SIZE
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR
cmd="\
mpirun -l -np 2 -ppn 2 \
-genv MPIR_CVAR_ENABLE_GPU=1 \
-genv MPIR_CVAR_CH4_RESERVE_VCIS=1 \
-genv UCX_LOG_LEVEL=debug \
-genv UCX_LOG_FILE=$PROJECT/logs/mpich-test-ucx.%h.%p.log \
-genv MPIR_CVAR_CH4_RUNTIME_CONF_DEBUG=1 \
-genv MPIR_CVAR_DEBUG_SUMMARY=1 \
python cookbook/benchmarks/communication/all_reduce.py --scan \
--backend=mpi \
--dist=torch \
--maxsize=4 \
--warmups=2 \
--trials=3"
output:
MPICH Version: 5.0.0a1
MPICH Release date: unreleased development copy
MPICH ABI: 0:0:0
MPICH Device: ch4:ucx
MPICH configure: --prefix=/home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6 --disable-fortran --disable-f77 --enable-cuda --with-cuda=/opt/cuda/12.6 --with-device=ch4:ucx --with-ucx=embedded CC=/opt/gcc/13.3.0/bin/gcc CXX=/opt/gcc/13.3.0/bin/g++ LDFLAGS=-lstdc++ CFLAGS=-I/opt/cuda/12.6/include
MPICH CC: /opt/gcc/13.3.0/bin/gcc -I/opt/cuda/12.6/include -O2
MPICH CXX: /opt/gcc/13.3.0/bin/g++ -O2
MPICH F77: /opt/gcc/13.3.0/bin/gfortran
MPICH FC: /opt/gcc/13.3.0/bin/gfortran
MPICH features: threadcomm
WORLD_SIZE=2
MASTER_ADDR=a100-10
mpirun -l -np 2 -ppn 2 -genv MPIR_CVAR_ENABLE_GPU=1 -genv MPIR_CVAR_CH4_RESERVE_VCIS=1 -genv UCX_LOG_LEVEL=debug -genv UCX_LOG_FILE=/home/xu.3304/project/mpix-stream/logs/mpich-test-ucx.%h.%p.log -genv MPIR_CVAR_CH4_RUNTIME_CONF_DEBUG=1 -genv MPIR_CVAR_DEBUG_SUMMARY=1 python cookbook/benchmarks/communication/all_reduce.py --scan --backend=mpi --dist=torch --maxsize=4 --warmups=2 --trials=3
[0] ==== GPU Init (CUDA) ====
[0] device_count: 2
[0] CUDA_VISIBLE_DEVICES: 0,1
[0] =========================
[0] ==== UCX netmod Capability ====
[0] MPIDI_UCX_CONTEXT_ID_BITS: 16
[0] MPIDI_UCX_RANK_BITS: 16
[0] tag_bits: 31
[0] ===============================
[0] ==== Various sizes and limits ====
[0] sizeof(MPIDI_per_vci_t): 128
[0] ==== collective selection ====
[0] MPIR_CVAR_DEVICE_COLLECTIVES: percoll
[0] MPIR: MPII_coll_generic_json
[0] MPID: MPIDI_coll_generic_json
[0] MPID (GPU): MPIDI_coll_generic_json
[0] num_vcis: 2
The MPICH being used is at commit 7f00e56, built with CUDA 12.6 and the embedded UCX; the PyTorch version is 2.7.1, built from source. For your reference, the hang occurred during initialization here.
A full backtrace was captured with: gdb -q -batch -p <PID> -ex "thread apply all bt full"
I omitted redundant traces from the other threads.
Thread 66 (Thread 0x1493ef53e700 (LWP 192741)):
#0 0x000014951cf17307 in epoll_wait () from /lib64/libc.so.6
No symbol table info available.
#1 0x0000149502c62713 in ucs_event_set_wait (event_set=0x55b0564bb1f0, num_events=num_events@entry=0x1493ef53de9c, timeout_ms=99, event_set_handler=event_set_handler@entry=0x149502c3bab0 <ucs_async_thread_ev_handler>, arg=arg@entry=0x1493ef53dea0) at sys/event_set.c:198
events = 0x1493ef53dd50
nready = <optimized out>
i = <optimized out>
io_events = <optimized out>
__func__ = "ucs_event_set_wait"
#2 0x0000149502c3c06e in ucs_async_thread_func (arg=0x55b056453160) at async/thread.c:131
thread = 0x55b056453160
last_time = 15654446212342260
curr_time = 15654446212342260
timer_interval = 199665882
time_spent = <optimized out>
is_missed = 0
timeout_ms = <optimized out>
status = <optimized out>
num_events = 16
cb_arg = {thread = 0x55b056453160, is_missed = 0x1493ef53de98}
__func__ = "ucs_async_thread_func"
#3 0x000014951d9401ca in start_thread () from /lib64/libpthread.so.0
No symbol table info available.
#4 0x000014951ce118d3 in clone () from /lib64/libc.so.6
No symbol table info available.
Thread 65 (Thread 0x1493fa8a5700 (LWP 192729)):
#0 0x000014951cf0aac1 in poll () from /lib64/libc.so.6
No symbol table info available.
#1 0x00001495037c5e4f in ?? () from /lib64/libcuda.so.1
No symbol table info available.
#2 0x000014950389912f in ?? () from /lib64/libcuda.so.1
No symbol table info available.
#3 0x00001495037c20f3 in ?? () from /lib64/libcuda.so.1
No symbol table info available.
#4 0x000014951d9401ca in start_thread () from /lib64/libpthread.so.0
No symbol table info available.
#5 0x000014951ce118d3 in clone () from /lib64/libc.so.6
No symbol table info available.
Thread 64 (Thread 0x1493fd842700 (LWP 192728)):
#0 0x000014951d94647c in pthread_cond_wait@@GLIBC_2.3.2 () from /lib64/libpthread.so.0
No symbol table info available.
#1 0x0000149485d6a81b in blas_thread_server () from /home/xu.3304/project/mpix-stream/miniconda3/envs/torch-2.7.1-mpich/lib/python3.12/site-packages/numpy/_core/../../numpy.libs/libscipy_openblas64_-8fb3d286.so
No symbol table info available.
#2 0x000014951d9401ca in start_thread () from /lib64/libpthread.so.0
No symbol table info available.
#3 0x000014951ce118d3 in clone () from /lib64/libc.so.6
No symbol table info available.
...
Thread 1 (Thread 0x14951dd6e400 (LWP 192599)):
#0 0x0000149502c4550e in ucs_recursive_spin_lock (lock=0x55b05642d240) at /home/xu.3304/project/mpix-stream/mpich/modules/ucx/src/ucs/type/spinlock.h:90
self = 22630683304960
self = <optimized out>
#1 ucs_callbackq_enter (cbq=0x55b0564e0a50) at datastruct/callbackq.c:87
No locals.
#2 ucs_callbackq_spill_elems_dispatch (cbq=0x55b0564e0a50) at datastruct/callbackq.c:382
count = 0
spill_elem = <optimized out>
spill_elem_idx = 0
priv = 0x55b05642d240
num_spill_elems = 2
tmp_elem = {cb = <optimized out>, arg = 0x55b0564e1160}
priv = <optimized out>
num_spill_elems = <optimized out>
count = <optimized out>
spill_elem = <optimized out>
tmp_elem = <optimized out>
spill_elem_idx = <optimized out>
#3 ucs_callbackq_proxy_callback (arg=0x55b0564e0a50) at datastruct/callbackq.c:476
cbq = 0x55b0564e0a50
count = <optimized out>
__func__ = "ucs_callbackq_proxy_callback"
#4 0x000014950324e48a in ucs_callbackq_dispatch (cbq=<optimized out>) at /home/xu.3304/project/mpix-stream/mpich/modules/ucx/src/ucs/datastruct/callbackq.h:215
elem = 0x55b0564e0aa0
cb = <optimized out>
count = 0
#5 uct_worker_progress (worker=<optimized out>) at /home/xu.3304/project/mpix-stream/mpich/modules/ucx/src/uct/api/uct.h:2823
No locals.
#6 ucp_worker_progress (worker=0x55b055e39a90) at core/ucp_worker.c:3059
count = <optimized out>
__func__ = "ucp_worker_progress"
#7 0x000014950ca3b1ed in flush_all () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#8 0x000014950ca3ba9f in MPIDI_UCX_comm_set_vcis () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#9 0x000014950ca2000d in MPIDI_Comm_set_vcis () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#10 0x000014950ca11f6a in MPID_Comm_commit_post_hook () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#11 0x000014950c9750b4 in MPIR_Comm_commit () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#12 0x000014950c9731d7 in MPIR_init_comm_world () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#13 0x000014950c99e725 in MPII_Init_thread () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#14 0x000014950c7f0d80 in PMPI_Init_thread () from /home/xu.3304/project/mpix-stream/install/mpich-7f00e56-cuda12.6/lib/libmpi.so.0
No symbol table info available.
#15 0x00001494f1776582 in operator() (__closure=0x7ffca4fa79ff) at /home/xu.3304/project/mpix-stream/pytorch-mpich/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp:485
mpiStatus = -1527087723
mpi_was_initialized = 0
__func__ = "operator()"
#16 0x00001494f1776c75 in c10d::ProcessGroupMPI::initMPIOnce () at /home/xu.3304/project/mpix-stream/pytorch-mpich/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp:502
init_mpi_flag = false
#17 0x00001494f1776cd7 in c10d::ProcessGroupMPI::createProcessGroupMPI (ranks=...) at /home/xu.3304/project/mpix-stream/pytorch-mpich/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp:508
groupComm = 0
rank = 1716397447
size = 5267
__func__ = "createProcessGroupMPI"
The full log is attached: gdbbt.mpich.pytorch.log
There is also a separate UCX log attached: mpich-test-ucx.a100-10.192599.log
The same build passes the MPICH test suite ("No Errors").
Forcing MPI_THREAD_MULTIPLE did not help either.
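(By "forcing" I mean requesting MPI_THREAD_MULTIPLE instead of the level ProcessGroupMPI asks for by default, which I believe is MPI_THREAD_SERIALIZED. A minimal standalone sketch of that call, independent of PyTorch and analogous to the PMPI_Init_thread call in frames #14-#16 above; run with MPIR_CVAR_CH4_RESERVE_VCIS=1 to mimic the failing configuration:)

#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv)
{
    int provided;
    /* Request MPI_THREAD_MULTIPLE explicitly; with
     * MPIR_CVAR_CH4_RESERVE_VCIS=1 this should exercise the same
     * MPIDI_Comm_set_vcis path that hangs in the PyTorch run. */
    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
    printf("provided thread level: %d\n", provided);
    MPI_Finalize();
    return 0;
}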
I suspect this might be a PyTorch problem given the history (here), but I am hoping to get some intuition from maintainers or users who may have encountered similar issues, to better understand what might be happening. Even some debugging tips that could lead to the root cause would be greatly appreciated.
Please let me know if you need more context!