RCCL deadlock discussion (was: increase NCCL_STEPS to match WARPSIZE/4)
Details
Work item: GH issue pending
What were the changes?
Increase the pipeline size (NCCL_STEPS) from 8 to 16 to maintain correct synchronization behavior
Why were the changes made?
On Frontier at OLCF (nodes with 8 x MI250X GPUs), RCCL sporadically stalls, especially when running on many (>~ 128) nodes or when using the libfabric plugin (aws-ofi-rccl). The proposed change adjusts the pipeline width of the 'simple' communication protocol to be compatible with the larger warp (wavefront) size on AMD hardware, which is 64, compared to 32 on NVIDIA hardware.
How was the outcome achieved?
A debug trace with rocgdb demonstrated that, under stall conditions during an NCCL broadcast, some threads on the first rank in the ring (GPU 0) hang inside a busy-wait loop (waitPeer): https://github.com/ROCm/rccl/blob/4237caad6934872f441212a346007cc063981d0e/src/device/prims_simple.h#L118-L127
A closer inspection revealed that GPU 0 was eight steps (NCCL_STEPS=8) ahead of GPU 1 and waiting for GPU 1 to be ready before it could send, while GPU 1 was in turn waiting on GPU 0 in order to receive; i.e., the two GPUs were caught in a circular deadlock inside waitPeer, which should not be possible by design. However, to ensure asynchronous progress (i.e., a GPU sending and receiving at the same time), the original NCCL design uses warp-level synchronization and a FIFO pipeline width that matches the maximum number (8) of progress groups of four threads each (waitRecv, waitSend, postRecv, and postSend) that fit into a warp. If more groups than this fit into a warp, the warp-level synchronization can stall the asynchronous progress of another pipeline. Increasing NCCL_STEPS to 16 guarantees that one wavefront maps onto exactly one pipeline.
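To make the arithmetic behind the proposed value explicit, here is a back-of-the-envelope sketch in Python (not RCCL source; it only restates the WARPSIZE/4 relation from the title):
# One pipeline slot is serviced by one progress group of 4 threads
# (waitRecv, waitSend, postRecv, postSend), so the number of groups that
# fit into a single warp/wavefront is WARP_SIZE // 4.
for warp_size in (32, 64):  # NVIDIA warp vs. AMD wavefront
    groups_per_warp = warp_size // 4
    print(f"WARP_SIZE={warp_size}: {groups_per_warp} groups -> NCCL_STEPS={groups_per_warp}")
# WARP_SIZE=32: 8 groups  -> NCCL_STEPS=8  (the current value, tuned for NVIDIA)
# WARP_SIZE=64: 16 groups -> NCCL_STEPS=16 (the proposed value, one wavefront per pipeline)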
Additional Documentation:
It may be necessary to assess the performance implications of this patch and re-tune collective parameters.
Regarding a reproducer: on Frontier, the stalling behavior is also modulated by the libfabric version. I had the most "luck" with libfabric 1.20.1 with the CXI provider, ROCm 6.2.4, and aws-ofi-rccl commit 17d41cbf5618536c4b1076d29748416ab307040f. I have a local reproducer code, which I shared with AMD developers and which uses JAX, but it is not yet minimal.
Approval Checklist
Do not approve until these items are satisfied.
- [ ] Verify the CHANGELOG has been updated, if
- there are any NCCL API version changes,
- any changes impact library users, and/or
- any changes impact any other ROCm library.
Thank you @jglaser! This is a valuable contribution.
@jglaser While I believe the deadlock you mentioned could be real, I am not convinced that the cause is WARP_SIZE. @nicholasmalaya can you help arrange a call to discuss? I think I still have Jens' email address, as we exchanged some emails 3 years ago on a different issue.
Yes. I will reach out for us to discuss this in more detail.
@jglaser can you help dump the proxy state with the patch below, to confirm the reason for the sender/receiver being stuck?
diff --git a/src/transport/net.cc b/src/transport/net.cc
index 5fd36f6f..16f07ba5 100644
--- a/src/transport/net.cc
+++ b/src/transport/net.cc
@@ -1379,6 +1379,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct
if (sub->done == sub->nsteps) *sendHead = sub->base + args->sliceSteps;
} else {
*sendHead = sub->base + sub->done;
+ INFO(NCCL_COLL, "Send chan %d Posted %lx to %p", resources->channelId, sub->base + sub->done, sendHead);
}
if (resources->gdcSync) wc_store_fence(); // Flush out WC write
}
@@ -1662,8 +1663,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct
if (sub->reg) {
// We may have added more net steps, but reg operations only have a single step w.r.t. the GPU.
if (sub->transmitted == sub->nsteps) *recvTail = sub->base + args->sliceSteps;
- } else
+ } else {
*recvTail = sub->base + sub->transmitted;
+ INFO(NCCL_COLL, "Recv chan %d Posted %lx to %p", resources->channelId, sub->base + sub->transmitted, recvTail);
+ }
if (resources->gdcSync) wc_store_fence(); // Flush out WC write
}
}
OK, I think the NCCL_STEPS change might actually be unnecessary, as the stalling behavior seems to be already improved (if not gone!) as of https://github.com/ROCm/rccl/commit/36343be84f653416cfe0399915db931b96ea636f, which is a squashed merge commit from upstream NCCL.
Interestingly, RCCL performance also seems to be significantly improved (>2x) since that commit.
To find out, I used git bisect with a simple reproducer (a shell script, really), which interrupts one of the GPUs in the ring using rocgdb for a few seconds. This simulates a slow rank or an RDMA bottleneck. By simultaneously monitoring GPU usage (watch rocm-smi), one should see all GPUs performing work, then GPU 1 throttling to 0 W and the other MI250X GPUs to around 120-130 W, and finally all of them returning to the original power. If a deadlock occurs, GPU 1 only returns to the busy-wait power level of the other GPUs after continuation in rocgdb. For commits prior to https://github.com/ROCm/rccl/commit/36343be84f653416cfe0399915db931b96ea636f, the chance of deadlock is roughly 20% on 16 nodes (128 GPUs).
# save as wrapper.sh
# Rank 1 launches the wrapped command in the background, then attaches rocgdb
# to interrupt it for ~15 s (simulating a slow rank); all other ranks run the
# command unmodified.
if [ "${SLURM_PROCID}" = 1 ]; then
  "${@}" &
  pid=${!}
  sleep 20
  # extract the first PID from the process tree (the process to attach rocgdb to)
  my_pid=`pstree -p "${pid}" | awk -F'[()]' '{print $2; exit}'`
  pstree -p "${pid}"
  echo "Interrupting process ${my_pid}"
  { echo "t a a bt"   # dump backtraces of all threads while the process is stopped
    sleep 15          # keep the process interrupted
    echo "cont"
    echo "detach"
    echo "quit"
  } | \
  rocgdb -p ${my_pid}
  wait
else
  "${@}"
fi
Use the wrapper with rccl-tests like so
LD_LIBRARY_PATH=<path to rccl.so>:${LD_LIBRARY_PATH} \
NCCL_NET_GDR_LEVEL=3 \
LD_LIBRARY_PATH=<path to aws-ofi-rccl>/src/.libs:${LD_LIBRARY_PATH} \
srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 5000" -b 1G
As the commit is unfortunately quite large, it would be interesting to pinpoint the source code line that caused the change in behavior.
@jglaser when there is a large number of GPUs, it is hard to debug with rocgdb. For example, we may find that GPU N is stuck waiting to receive, but that is because GPU N-1 didn't send; and GPU N-1 didn't send because it is waiting to receive from GPU N-2... We need the logs from the entire ring to identify the root of the problem. The latest develop branch has a lot of improvements for getting GPU kernel logs in this scenario, but CPU proxy-side logging is still lacking.
Here is a stalled trace out_435756af02a976d9686e5f6bc8954727dc49a29b_lasthalf.txt.gz with the ROCM 6.2 commit
after the "Continuing" the GPUs are hanging
=========================================== ROCm System Management Interface ===========================================
===================================================== Concise Info =====================================================
Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
(DID, GUID) (Edge) (Avg) (Mem, Compute, ID)
========================================================================================================================
0 4 0x7408, 63582 38.0°C 131.0W N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 560.0W 2% 100%
1 5 0x7408, 51740 42.0°C N/A N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 0.0W 2% 100%
2 6 0x7408, 15961 33.0°C 129.0W N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 560.0W 2% 100%
3 7 0x7408, 3099 36.0°C N/A N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 0.0W 2% 100%
4 8 0x7408, 13395 38.0°C 129.0W N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 560.0W 2% 100%
5 9 0x7408, 1553 41.0°C N/A N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 0.0W 2% 100%
6 10 0x7408, 62036 34.0°C 126.0W N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 560.0W 2% 100%
7 11 0x7408, 49174 38.0°C N/A N/A, N/A, 0 1700Mhz 1600Mhz 0% manual 0.0W 2% 100%
========================================================================================================================
================================================= End of ROCm SMI Log ==================================================
and a successful one with the latest HEAD out_ccb082074351b560bbce3e1cb8d9ae2045b7beac.txt.gz
Command:
NCCL_DEBUG_SUBSYS=COLL NCCL_DEBUG=info LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/rccl/build/debug/:${LD_LIBRARY_PATH} NCCL_NET_GDR_LEVEL=3 LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/aws-ofi-rccl/src/.libs:${LD_LIBRARY_PATH} srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 100"
And here is the first half of the stalled trace, too, which includes the RCCL setup: out_435756af02a976d9686e5f6bc8954727dc49a29b_firsthalf.txt.gz
Caveat: I have another example here (RL training) which freezes even with the ccb08 commit (in an allgather).
@jglaser Increasing NCCL_STEPS makes the GPU kernel-side FIFO deeper. I am wondering if this somehow helps with the network stack, which has another FIFO. There is NCCL_OFI_MAX_REQUESTS in the aws-ofi plugin: https://github.com/ROCm/aws-ofi-rccl/blob/17d41cbf5618536c4b1076d29748416ab307040f/include/nccl_ofi.h#L68 Can you try increasing or decreasing this number to see if it has any effect?
Another observation: using the reproducer as before with the failing commit 43575..., when RCCL has been compiled with -O0 (no optimizations) the pipeline recovers from the interruption, which it does not with -O3. This hints at a subtle correctness issue such as memory ordering, or at starvation/occupancy effects (e.g. due to the spin loops)... investigating further...
I should perhaps mention that I replaced the thread barrier in that commit with my own, but I am not sure whether that led to the optimization-level-dependent behavior (this barrier avoids potential memory ordering issues by relying entirely on atomicCAS):
template<typename T>
__device__ T atomicBarrierCAS(T *address, int numThreads, int expectedPhase) {
T old = *address; // Read the initial state once before entering the loop
T assumed;
do {
assumed = (expectedPhase << 31) | (old & 0x7FFFFFFF); // Ensure phase is expected
int counter = assumed & 0x7FFFFFFF; // Extract counter (lower 31 bits)
int phase = (assumed >> 31) & 1; // Extract phase bit (highest bit)
// Determine the next phase, but don't use it for the increment/decrement
int nextPhase = phase;
if ((phase == 0 && counter == numThreads - 1) || (phase == 1 && counter == 1)) {
nextPhase = 1 - phase; // Flip phase for the next barrier
}
// Increment or decrement based on the original phase, not nextPhase
int newCounter = (phase == 0) ? (counter + 1) : (counter - 1);
// Construct new packed value
T newValue = (nextPhase << 31) | newCounter;
old = atomicCAS(address, assumed, newValue); // Attempt to update
} while (old != assumed); // Retry if atomicCAS failed
return old; // Return the final observed value
}
#define barrier_by_group() { \
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
uint64_t num_leaders = nthreads/WARP_SIZE; \
atomicBarrierCAS(barriers, num_leaders, 0); \
atomicBarrierCAS(barriers, num_leaders, 1); \
} \
}
@jglaser if you are testing collectives like allreduce, broadcast, etc., barrier_by_group() will call __builtin_amdgcn_s_barrier(), because RCCL will not reduce nthreads from NCCL_MAX_NTHREADS, which is 256.
I think I found what was causing the hang in my application: multiple (N/R)CCL communicators used by multiple CPU threads. It is well known that NCCL is not thread safe, see e.g. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently
When I analyzed the call pattern from my app, I noticed that there was more than one communicator (intra- plus inter-node from FSDP HYBRID_SHARD). Each communicator had its own HIP stream associated with it. However, these communicators were shared between different host threads. That creates a hazard: the launch order across host threads is random, whereas the GPU serializes kernels launched to the same stream, so deadlocks are expected.
I confirmed the thread-unsafe behavior of RCCL with this code:
#include <iostream>
#include <thread>
#include <chrono>
#include <vector>
#include <cstdlib>
#include <sstream>
#include <rccl/rccl.h>
#include <hip/hip_runtime.h>
#include <mpi.h>
#define CHECK_HIP(call) \
do { \
hipError_t err = call; \
if (err != hipSuccess) { \
std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \
MPI_Abort(MPI_COMM_WORLD, 1); \
} \
} while (0)
#define CHECK_RCCL(call) \
do { \
ncclResult_t res = call; \
if (res != ncclSuccess) { \
std::cerr << "RCCL Error: " << ncclGetErrorString(res) << "\n"; \
MPI_Abort(MPI_COMM_WORLD, 1); \
} \
} while (0)
void run_nccl_op(ncclComm_t comm, hipStream_t stream, int device, int count, int delay_us, int thread_id, int world_rank) {
CHECK_HIP(hipSetDevice(device));
float *sendbuf, *recvbuf;
CHECK_HIP(hipMalloc(&sendbuf, count * sizeof(float)));
CHECK_HIP(hipMalloc(&recvbuf, count * sizeof(float)));
std::this_thread::sleep_for(std::chrono::microseconds(delay_us));
CHECK_RCCL(ncclAllGather(sendbuf, recvbuf, count, ncclFloat, comm, stream));
CHECK_HIP(hipFree(sendbuf));
CHECK_HIP(hipFree(recvbuf));
if (world_rank == 0) {
std::ostringstream msg;
msg << "Thread " << thread_id << " completed on rank 0\n";
std::cerr << msg.str();
}
}
int main(int argc, char* argv[]) {
MPI_Init(&argc, &argv);
int world_size, world_rank;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
int device = 0; // use fixed device, rely on ROCR_VISIBLE_DEVICES
int count = 1024;
int num_threads = 2;
if (argc > 1) {
num_threads = std::atoi(argv[1]);
}
ncclComm_t comm;
hipStream_t stream;
CHECK_HIP(hipSetDevice(device));
CHECK_HIP(hipStreamCreate(&stream));
ncclUniqueId id;
if (world_rank == 0) {
ncclGetUniqueId(&id);
}
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
CHECK_RCCL(ncclCommInitRank(&comm, world_size, id, world_rank));
std::vector<std::thread> threads;
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back(run_nccl_op, comm, stream, device, count, i * 10, i, world_rank);
}
for (auto& t : threads) t.join();
CHECK_HIP(hipStreamSynchronize(stream));
ncclCommDestroy(comm);
CHECK_HIP(hipStreamDestroy(stream));
std::cout << "Rank " << world_rank << " completed" << std::endl;
MPI_Finalize();
return 0;
}
RCCL thread-unsafe reproducer
This demonstration confirms that launching RCCL collectives on the same GPU stream, but from different host threads, produces non-deterministic results.
module load rocm/6.2.4
hipcc -o rccl_race_test rccl_race_test.cu -lrccl -L${MPICH_DIR}/lib -lmpi ${CRAY_XPMEM_POST_LINK_OPTS} -I${MPICH_DIR}/include
# interactive job with 1 gpu/process (ROCR_VISIBLE_DEVICES=0) on 1 node
salloc -N 1 -t 02:00:00 -A <proid> -p batch -q debug --cpus-per-task=8 -S 0 --tasks-per-node=8 --gpus-per-task=1
# run 1 w/4 threads
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Thread 3 completed on rank 0
Thread 1 completed on rank 0
Thread 0 completed on rank 0
Thread 2 completed on rank 0
Rank 4 completed
Rank 6 completed
Rank 2 completed
Rank 1 completed
Rank 3 completed
Rank 5 completed
Rank 7 completed
Rank 0 completed
# run 2
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Rank 1 completed
^Csrun: interrupt (one more within 1 sec to abort)
srun: StepId=3231174.44 tasks 0-7: running
^Csrun: sending Ctrl-C to StepId=3231174.44
srun: forcing job termination
^C^Csrun: sending Ctrl-C to StepId=3231174.44
srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
srun: Terminating StepId=3231174.44
srun: job abort in progress
# baseline: 1 CPU thread, 10 runs
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> for Z in `seq 0 10`; do srun -n 8 rccl_race_test 1; done
Thread 0 completed on rank 0
Rank 6 completed
Rank 0 completed
Rank 4 completed
Rank 2 completed
Rank 1 completed
Rank 5 completed
Rank 7 completed
Rank 3 completed
srun: Step created for StepId=3231264.2
Thread 0 completed on rank 0
Rank 0 completed
Rank 6 completed
Rank 4 completed
On the Python side, I could confirm that PyTorch uses multiple threads to parallelize the backward pass:
# torch_reproducer.py
import torch
from torch.nn import Linear
model = Linear(4, 4).cuda()
input = torch.randn(1, 4, device='cuda', requires_grad=True)
def custom_hook(grad):
import os, threading
print(f"[Grad Hook] PID {os.getpid()} Thread {threading.get_native_id()}")
return grad
output = model(input)
output.register_hook(custom_hook)
loss = output.sum()
loss.backward()
output
[Grad Hook] PID 3371415 Thread 3373242
Notice PID != Thread ID
Now the mitigation strategy became obvious:
- Put all collectives into the main thread
- Make sure streams synchronize between calls to different communicators
Torch has an option to turn off threading during autograd:
with torch.autograd.set_multithreading_enabled(False):
    # trainer.train() or model.backward() ....
To ensure that streams sync before and after collectives, a little more effort is needed: another context manager
import contextlib
import torch
import functools
import os
import torch.distributed as dist
# Get native thread ID
import ctypes
import threading
libc = ctypes.CDLL("libc.so.6")
gettid = libc.syscall
SYS_gettid = 186 # x86_64 Linux syscall number for gettid
@contextlib.contextmanager
def patch_distributed_collectives(label="DISTRIBUTED"):
patched = {}
dist_ops = ['all_reduce', 'reduce_scatter', 'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'all_reduce', 'reduce',
'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier' ]
def make_wrapper(opname, fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
thread_id = threading.get_ident()
native_tid = gettid(SYS_gettid)
pid = os.getpid()
try:
rank = dist.get_rank()
except RuntimeError:
rank = -1
#print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname}")
stream = torch.cuda.current_stream()
stream.synchronize()
try:
return fn(*args, **kwargs)
finally:
stream.synchronize()
return wrapper
try:
for opname in dist_ops:
if hasattr(dist, opname):
original = getattr(dist, opname)
patched[opname] = original
setattr(dist, opname, make_wrapper(opname, original))
yield
finally:
for opname, original in patched.items():
setattr(dist, opname, original)
# Example usage:
# with patch_distributed_collectives():
# output = model(input)
Using these two together fixes the hang for me. Performance may not be optimal and NCCL usage in torch autograd may have to be re-examined.
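For reference, a minimal sketch of how the two mitigations combine (the trainer object is hypothetical; patch_distributed_collectives is the context manager defined above):
import torch

# Keep autograd single-threaded so the host-side launch order is deterministic,
# and synchronize the current stream around every torch.distributed collective.
with torch.autograd.set_multithreading_enabled(False), patch_distributed_collectives():
    trainer.train()  # hypothetical training entry point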
@thananon (as per offline discussion) try this patch ... if it still hangs, the next escalation strategy would be to replace the above two calls to stream.synchronize() with torch.cuda.synchronize(), to synchronize the entire device and not just the current stream
@jglaser You found me here. Yes, I incorporated this patch into the latest job as you suggested. I hope we get good results back.
After fixing the first hang, my app was able to make it to an error message, which I was able to fix. Subsequently, however, it still sometimes hung.
This is a minimal version of the patch that expands the coverage of distributed collectives, and which also only has a single (necessary) synchronization point.
import contextlib
import torch
import functools
import threading
import os
import torch.distributed as dist
@contextlib.contextmanager
def patch_distributed_collectives(logging=False):
patched = {}
dist_ops = ['all_reduce', 'reduce_scatter', 'reduce_scatter_tensor',
'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'reduce',
'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier',
'_broadcast_coalesced']
def make_wrapper(label, opname, fn, logging=False):
stream = torch.cuda.Stream()
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if logging:
thread_id = threading.get_ident()
native_tid = threading.current_thread().native_id
pid = os.getpid()
try:
rank = dist.get_rank()
except RuntimeError:
rank = -1
device = torch.cuda.current_device()
current_stream = torch.cuda.current_stream()
async_op = kwargs.get('async_op', False)
print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname} on device {device} stream {current_stream} async_op {async_op}")
current_stream = torch.cuda.current_stream()
try:
return fn(*args, **kwargs)
finally:
current_stream.synchronize()
return wrapper
try:
for opname in dist_ops:
if hasattr(dist, opname):
original = getattr(dist, opname)
patched[opname] = original
setattr(dist, opname, make_wrapper('DISTRIBUTED', opname, original, logging))
yield
finally:
for opname, original in patched.items():
setattr(dist, opname, original)
# Example usage:
# with patch_distributed_collectives():
# output = model(input)
Using this version, and with autograd threads disabled as before, my app runs to completion on 64 nodes using rccl commit https://github.com/ROCm/rccl/commit/532f54c2444501b3655e65fbce6d00d4bfc19c0b (or actually, to the next OOM :)
The likely cause is that running concurrent collectives on different streams/communicators (e.g. intra-/inter-node) still requires stream dependencies to be set, which may (or may not) be implemented in FSDP. The above sync should become unnecessary once NCCL upstream 2.26 is merged into RCCL, so that users can set NCCL_LAUNCH_ORDER_IMPLICIT; then they only need to ensure that the host-side order of calls is consistent, e.g. by disabling autograd threads.
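A minimal sketch of how that could look once the variable is honored by RCCL (assumptions: NCCL_LAUNCH_ORDER_IMPLICIT behaves as in upstream NCCL 2.26 and is read at communicator creation, so it must be set before the process group is initialized):
import os
import torch
import torch.distributed as dist

# Have NCCL order GPU execution of collectives across communicators by their host launch order.
os.environ["NCCL_LAUNCH_ORDER_IMPLICIT"] = "1"
dist.init_process_group(backend="nccl")

# A consistent host-side call order is still required, e.g. by disabling autograd threads.
with torch.autograd.set_multithreading_enabled(False):
    ...  # training loop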
https://github.com/pytorch/pytorch/issues/147729 seems related
@jglaser can you try building https://github.com/pytorch/pytorch/pull/148590 to see if it resolves the deadlock? Or do you not expect it to resolve the issue?
Will do -- and just FYI, I still did encounter a hang with my RL app during collectives, after switching to full BF16 training (instead of mixed).... will investigate @thananon
@jglaser Just realizing #148590 was relanded as https://github.com/pytorch/pytorch/commit/acf5139e57e5882e4526f784d35d4c52845f2fd4 due to merge conflict. So it's been in main since Mon Mar 31. No need to build that PR yourself, it should be in nightly wheels. Have you given a nightly wheel a try yet?
I tried to run your code but I am getting
RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)
RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)
This error may be expected with more than one thread. How many did you use (first command line argument)?
I'm closing this PR, as this conversation has moved to a meeting and has been resolved.