
RCCL deadlock discussion (was: increase NCCL_STEPS to match WARPSIZE/4)

Open jglaser opened this issue 8 months ago • 26 comments

Details

Work item: GH issue pending

What were the changes?
Increase the pipeline size (NCCL_STEPS) from 8 to 16 to maintain correct synchronization behavior

Why were the changes made?
On Frontier at OLCF, with 8 MI250X GCDs per node, RCCL sporadically stalls, especially when running on many (>~128) nodes or when using the libfabric plugin (aws-ofi-rccl). The proposed change adjusts the pipeline width in the 'simple' communication protocol to be compatible with the larger warp (wavefront) size on AMD hardware, which is 64, compared to 32 on NVIDIA hardware.

How was the outcome achieved?
A debug trace with rocgdb demonstrated that under stall conditions, during an NCCL broadcast, some threads on the first rank in the ring (GPU 0) hang inside a busy-wait loop (waitPeer): https://github.com/ROCm/rccl/blob/4237caad6934872f441212a346007cc063981d0e/src/device/prims_simple.h#L118-L127

A closer inspection revealed that GPU 0 was eight steps (NCCL_STEPS=8) ahead of GPU 1 and was waiting for GPU 1 to be ready before it could send, while GPU 1 was waiting on GPU 0 to receive, i.e., the two GPUs were caught in a circular deadlock inside waitPeer, which should be impossible by design. However, to ensure asynchronous progress (i.e., GPU 1 sending and receiving at the same time), the original NCCL design relies on warp-level synchronization and a FIFO pipeline width that matches the maximum number (8) of progress groups of four threads each (waitRecv, waitSend, postRecv, and postSend) that fit into a warp. If more groups than that fit into a warp, the warp-level synchronization can stall asynchronous progress made by another pipeline. Increasing NCCL_STEPS to 16 guarantees that one wavefront covers exactly one pipeline.
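
For illustration, a back-of-the-envelope sketch of the counting argument above (plain Python, not RCCL code); the names THREADS_PER_GROUP and groups_per_warp are mine, and the claim that matching the two quantities is sufficient is the hypothesis of this patch rather than established fact:

# Counting sketch: each pipeline step is serviced by a group of four
# synchronization roles (waitRecv, waitSend, postRecv, postSend), so a
# warp/wavefront holds WARP_SIZE / 4 such groups.
THREADS_PER_GROUP = 4  # waitRecv, waitSend, postRecv, postSend

for warp_size, nccl_steps in [(32, 8), (64, 8), (64, 16)]:
    groups_per_warp = warp_size // THREADS_PER_GROUP
    print(f"WARP_SIZE={warp_size:2d}  groups per warp={groups_per_warp:2d}  "
          f"NCCL_STEPS={nccl_steps:2d}  matched={groups_per_warp == nccl_steps}")

# NVIDIA (warp 32): 8 groups per warp, matching the stock NCCL_STEPS=8.
# AMD (wavefront 64): 16 groups per wavefront, matching only if NCCL_STEPS=16.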

Additional Documentation:
It may be necessary to assess the performance implications of this patch and re-tune collective parameters.

Regarding a reproducer: on Frontier, the stalling behavior is also modulated by the libfabric version. I had the most "luck" with libfabric 1.20.1 with the CXI provider, ROCm 6.2.4, and aws-ofi-rccl commit 17d41cbf5618536c4b1076d29748416ab307040f. I have a local reproducer code that I shared with AMD developers and which uses JAX, but it is not yet minimal.

Approval Checklist

Do not approve until these items are satisfied.

  • [ ] Verify the CHANGELOG has been updated, if
    • there are any NCCL API version changes,
    • any changes impact library users, and/or
    • any changes impact any other ROCm library.

jglaser avatar Mar 13 '25 19:03 jglaser

Thank you @jglaser! This is a valuable contribution.

nicholasmalaya avatar Mar 13 '25 20:03 nicholasmalaya

@jglaser While I believe the deadlock you mentioned could be real, I am not convinced that the cause is WARP_SIZE. @nicholasmalaya can you help arrange a call to discuss? I think I still have Jens' email address, as we exchanged some emails 3 years ago on a different issue.

wenkaidu avatar Mar 13 '25 20:03 wenkaidu

@jglaser While I believe the deadlock you mentioned could be real, I am not convinced that the cause is WARP_SIZE. @nicholasmalaya can you help arrange a call to discuss? I think I still have Jens' email address, as we exchanged some emails 3 years ago on a different issue.

Yes. I will reach out for us to discuss this in more detail.

nicholasmalaya avatar Mar 13 '25 20:03 nicholasmalaya

@jglaser can you help dump the proxy state with the patch below to confirm why the sender/receiver are stuck?

diff --git a/src/transport/net.cc b/src/transport/net.cc
index 5fd36f6f..16f07ba5 100644
--- a/src/transport/net.cc
+++ b/src/transport/net.cc
@@ -1379,6 +1379,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct
                 if (sub->done == sub->nsteps) *sendHead = sub->base + args->sliceSteps;
               } else {
                 *sendHead = sub->base + sub->done;
+                INFO(NCCL_COLL, "Send chan %d Posted %lx to %p", resources->channelId, sub->base + sub->done, sendHead);
               }
               if (resources->gdcSync) wc_store_fence(); // Flush out WC write
             }
@@ -1662,8 +1663,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct
               if (sub->reg) {
                 // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU.
                 if (sub->transmitted == sub->nsteps) *recvTail = sub->base + args->sliceSteps;
-              } else
+              } else {
                 *recvTail = sub->base + sub->transmitted;
+                INFO(NCCL_COLL, "Recv chan %d Posted %lx to %p", resources->channelId, sub->base + sub->transmitted, recvTail);
+              }
               if (resources->gdcSync) wc_store_fence(); // Flush out WC write
             }
           }

wenkaidu avatar Mar 14 '25 16:03 wenkaidu

OK, I think the NCCL_STEPS change might actually be unnecessary, as the stalling behavior seems to be already improved (if not gone!) as of https://github.com/ROCm/rccl/commit/36343be84f653416cfe0399915db931b96ea636f, which is a squashed merge commit from upstream NCCL.

Interestingly, RCCL performance also seems to be significantly improved (>2x) since that commit.

To narrow this down, I used git bisect with a simple reproducer (a shell script, really) that uses rocgdb to interrupt one of the GPUs in the ring for a few seconds. This simulates a slow rank or an RDMA bottleneck. By simultaneously monitoring GPU usage (watch rocm-smi), one should see all GPUs performing work, then GPU 1 throttling to 0 W and the other MI250X GPUs to around 120-130 W, and finally all of them returning to the original power. If a deadlock occurs, GPU 1 only returns to the other GPUs' busy-wait power level after continuing in rocgdb. For commits prior to https://github.com/ROCm/rccl/commit/36343be84f653416cfe0399915db931b96ea636f, the chance of deadlock is roughly 20% on 16 nodes (128 GPUs).

# save as wrapper.sh
# Rank 1 launches the command in the background, waits 20 s, then attaches rocgdb,
# dumps backtraces of all threads, keeps the process stopped for ~15 s, and finally
# continues and detaches. All other ranks run the command unmodified.
if [ "${SLURM_PROCID}" = 1 ]; then
    "${@}" &
    pid=${!}
    sleep 20
    my_pid=$(pstree -p "${pid}" | awk -F'[()]' '{print $2; exit}')
    pstree -p "${pid}"
    echo "Interrupting process ${my_pid}"
    {  echo "t a a bt"
       sleep 15
       echo "cont"
       echo "detach"
       echo "quit"
     } | \
    rocgdb -p "${my_pid}"
    wait
else
    "${@}"
fi

Use the wrapper with rccl-tests like so

LD_LIBRARY_PATH=<path to rccl.so>:${LD_LIBRARY_PATH} \
NCCL_NET_GDR_LEVEL=3 \
LD_LIBRARY_PATH=<path to aws-ofi-rccl>/src/.libs:${LD_LIBRARY_PATH} \
srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 5000" -b 1G

As the commit is unfortunately quite large, it would be interesting to pinpoint the source code line that caused the change in behavior.

jglaser avatar Mar 17 '25 09:03 jglaser

@jglaser when there is a large number of GPUs, it is hard to debug with rocgdb. For example, we may find that GPU N is stuck waiting for a receive, but that is because GPU N-1 didn't send. And GPU N-1 didn't send because it is waiting for a receive from GPU N-2... We need the logs from the entire ring to identify the root of the problem. The latest develop branch has a lot of improvements in getting GPU kernel logs for this scenario, but CPU proxy-side logging is still lacking.

wenkaidu avatar Mar 17 '25 15:03 wenkaidu

@jglaser when there is a large number of GPUs, it is hard to debug with rocgdb. For example, we may find that GPU N is stuck waiting for a receive, but that is because GPU N-1 didn't send. And GPU N-1 didn't send because it is waiting for a receive from GPU N-2... We need the logs from the entire ring to identify the root of the problem. The latest develop branch has a lot of improvements in getting GPU kernel logs for this scenario, but CPU proxy-side logging is still lacking.

Here is a stalled trace out_435756af02a976d9686e5f6bc8954727dc49a29b_lasthalf.txt.gz with the ROCm 6.2 commit.

After the "Continuing" message, the GPUs are hanging:

=========================================== ROCm System Management Interface ===========================================
===================================================== Concise Info =====================================================
Device  Node  IDs              Temp    Power   Partitions          SCLK     MCLK     Fan  Perf    PwrCap  VRAM%  GPU%
              (DID,     GUID)  (Edge)  (Avg)   (Mem, Compute, ID)
========================================================================================================================
0       4     0x7408,   63582  38.0°C  131.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
1       5     0x7408,   51740  42.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
2       6     0x7408,   15961  33.0°C  129.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
3       7     0x7408,   3099   36.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
4       8     0x7408,   13395  38.0°C  129.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
5       9     0x7408,   1553   41.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
6       10    0x7408,   62036  34.0°C  126.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
7       11    0x7408,   49174  38.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
========================================================================================================================
================================================= End of ROCm SMI Log ==================================================

and a successful one with the latest HEAD out_ccb082074351b560bbce3e1cb8d9ae2045b7beac.txt.gz

Command:

NCCL_DEBUG_SUBSYS=COLL NCCL_DEBUG=info LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/rccl/build/debug/:${LD_LIBRARY_PATH} NCCL_NET_GDR_LEVEL=3 LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/aws-ofi-rccl/src/.libs:${LD_LIBRARY_PATH} srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 100"

jglaser avatar Mar 17 '25 20:03 jglaser

Here is a stalled trace out_435756af02a976d9686e5f6bc8954727dc49a29b_lasthalf.txt.gz with the ROCm 6.2 commit.

and the first half, too, which includes the RCCL setup: out_435756af02a976d9686e5f6bc8954727dc49a29b_firsthalf.txt.gz

jglaser avatar Mar 17 '25 20:03 jglaser

Caveat: I have another example here (RL training) which freezes even with the ccb08 commit (in an allgather).

jglaser avatar Mar 18 '25 00:03 jglaser

@jglaser Increasing NCCL_STEPS makes the GPU kernel-side FIFO deeper. I am wondering if this somehow helps with the network stack, which has another FIFO. There is NCCL_OFI_MAX_REQUESTS in the aws-ofi plugin: https://github.com/ROCm/aws-ofi-rccl/blob/17d41cbf5618536c4b1076d29748416ab307040f/include/nccl_ofi.h#L68 Can you try increasing or decreasing this number to see if it has any effect?

wenkaidu avatar Mar 18 '25 16:03 wenkaidu

Another observation: using the reproducer as before with the failing commit 43575..., when RCCL is compiled with -O0 (no optimizations) the pipeline recovers from the interruption, which it does not with -O3. This hints at a subtle correctness issue such as memory ordering, or at starvation/occupancy effects (e.g. due to the spin loops)... investigating further...

jglaser avatar Mar 20 '25 18:03 jglaser

Another observation: using the reproducer as before with the failing commit 43575..., when RCCL is compiled with -O0 (no optimizations) the pipeline recovers from the interruption, which it does not with -O3. This hints at a subtle correctness issue such as memory ordering, or at starvation/occupancy effects (e.g. due to the spin loops)... investigating further...

I should perhaps mention that I replaced the thread barrier in that commit with my own, but I am not sure whether that led to the optimization-level-dependent behavior (this one avoids potential memory ordering issues by relying entirely on atomicCAS):

template<typename T>
__device__ T atomicBarrierCAS(T *address, int numThreads, int expectedPhase) {
    T old = *address;  // Read the initial state once before entering the loop
    T assumed;

    do {
        assumed = (expectedPhase << 31) | (old & 0x7FFFFFFF); // Ensure phase is expected

        int counter = assumed & 0x7FFFFFFF; // Extract counter (lower 31 bits)
        int phase = (assumed >> 31) & 1;    // Extract phase bit (highest bit)

        // Determine the next phase, but don't use it for the increment/decrement
        int nextPhase = phase;
        if ((phase == 0 && counter == numThreads - 1) || (phase == 1 && counter == 1)) {
            nextPhase = 1 - phase; // Flip phase for the next barrier
        }

        // Increment or decrement based on the original phase, not nextPhase
        int newCounter = (phase == 0) ? (counter + 1) : (counter - 1);

        // Construct new packed value
        T newValue = (nextPhase << 31) | newCounter;

        old = atomicCAS(address, assumed, newValue); // Attempt to update
    } while (old != assumed); // Retry if atomicCAS failed

    return old; // Return the final observed value
}

#define barrier_by_group() { \
    const int wid = threadIdx.x%WARP_SIZE; \
    if (wid == 0) { \
      uint64_t num_leaders = nthreads/WARP_SIZE; \
      atomicBarrierCAS(barriers, num_leaders, 0); \
      atomicBarrierCAS(barriers, num_leaders, 1); \
    } \
}

jglaser avatar Mar 20 '25 19:03 jglaser

@jglaser if you are testing collectives like allreduce, broadcast, etc., barrier_by_group() will call __builtin_amdgcn_s_barrier(), because RCCL will not reduce nthreads from NCCL_MAX_NTHREADS, which is 256.

wenkaidu avatar Mar 20 '25 19:03 wenkaidu

I think I found what was causing the hang in my application: multiple (N/R)CCL communicators used by multiple CPU threads. It is well known that NCCL is not thread safe, see e.g. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently

When I analyzed the call pattern of my app, I noticed that there was more than one communicator (intra- plus inter-node from FSDP HYBRID_SHARD). Each communicator had its own HIP stream associated with it. However, these communicators were shared between different host threads. That is a hazard: the launch order across host threads is effectively random, whereas the GPU serializes kernels launched to the same stream. Therefore, deadlocks are expected.

I confirmed the thread unsafe behavior of RCCL with this code

#include <iostream>
#include <thread>
#include <chrono>
#include <vector>
#include <cstdlib>
#include <sstream>
#include <rccl/rccl.h>
#include <hip/hip_runtime.h>
#include <mpi.h>

#define CHECK_HIP(call) \
    do { \
        hipError_t err = call; \
        if (err != hipSuccess) { \
            std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \
            MPI_Abort(MPI_COMM_WORLD, 1); \
        } \
    } while (0)

#define CHECK_RCCL(call) \
    do { \
        ncclResult_t res = call; \
        if (res != ncclSuccess) { \
            std::cerr << "RCCL Error: " << ncclGetErrorString(res) << "\n"; \
            MPI_Abort(MPI_COMM_WORLD, 1); \
        } \
    } while (0)

void run_nccl_op(ncclComm_t comm, hipStream_t stream, int device, int count, int delay_us, int thread_id, int world_rank) {
    CHECK_HIP(hipSetDevice(device));

    float *sendbuf, *recvbuf;
    CHECK_HIP(hipMalloc(&sendbuf, count * sizeof(float)));
    CHECK_HIP(hipMalloc(&recvbuf, count * sizeof(float)));

    std::this_thread::sleep_for(std::chrono::microseconds(delay_us));

    CHECK_RCCL(ncclAllGather(sendbuf, recvbuf, count, ncclFloat, comm, stream));

    CHECK_HIP(hipFree(sendbuf));
    CHECK_HIP(hipFree(recvbuf));

    if (world_rank == 0) {
        std::ostringstream msg;
        msg << "Thread " << thread_id << " completed on rank 0\n";
        std::cerr << msg.str();
    }
}

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int world_size, world_rank;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

    int device = 0; // use fixed device, rely on ROCR_VISIBLE_DEVICES
    int count = 1024;
    int num_threads = 2;

    if (argc > 1) {
        num_threads = std::atoi(argv[1]);
    }

    ncclComm_t comm;
    hipStream_t stream;

    CHECK_HIP(hipSetDevice(device));
    CHECK_HIP(hipStreamCreate(&stream));

    ncclUniqueId id;
    if (world_rank == 0) {
        ncclGetUniqueId(&id);
    }
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
    CHECK_RCCL(ncclCommInitRank(&comm, world_size, id, world_rank));

    std::vector<std::thread> threads;
    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back(run_nccl_op, comm, stream, device, count, i * 10, i, world_rank);
    }
    for (auto& t : threads) t.join();

    CHECK_HIP(hipStreamSynchronize(stream));

    ncclCommDestroy(comm);
    CHECK_HIP(hipStreamDestroy(stream));

    std::cout << "Rank " << world_rank << " completed" << std::endl;

    MPI_Finalize();
    return 0;
}

RCCL thread-unsafe reproducer

This demonstration confirms that launching RCCL collectives on the same GPU stream from different host threads produces non-deterministic results.

module load rocm/6.2.4
hipcc -o rccl_race_test rccl_race_test.cu -lrccl -L${MPICH_DIR}/lib -lmpi ${CRAY_XPMEM_POST_LINK_OPTS} -I${MPICH_DIR}/include
# interactive job with 1 gpu/process (ROCR_VISIBLE_DEVICES=0) on 1 node
salloc -N 1 -t 02:00:00 -A <proid> -p batch -q debug --cpus-per-task=8 -S 0 --tasks-per-node=8 --gpus-per-task=1
# run 1 w/4 threads
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Thread 3 completed on rank 0
Thread 1 completed on rank 0
Thread 0 completed on rank 0
Thread 2 completed on rank 0
Rank 4 completed
Rank 6 completed
Rank 2 completed
Rank 1 completed
Rank 3 completed
Rank 5 completed
Rank 7 completed
Rank 0 completed

# run 2
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Rank 1 completed
^Csrun: interrupt (one more within 1 sec to abort)
srun: StepId=3231174.44 tasks 0-7: running
^Csrun: sending Ctrl-C to StepId=3231174.44
srun: forcing job termination
^C^Csrun: sending Ctrl-C to StepId=3231174.44
srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
srun: Terminating StepId=3231174.44
srun: job abort in progress
# baseline: 1 CPU thread, 10 runs
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> for Z in `seq 0 10`; do srun -n 8 rccl_race_test 1; done
Thread 0 completed on rank 0
Rank 6 completed
Rank 0 completed
Rank 4 completed
Rank 2 completed
Rank 1 completed
Rank 5 completed
Rank 7 completed
Rank 3 completed
srun: Step created for StepId=3231264.2
Thread 0 completed on rank 0
Rank 0 completed
Rank 6 completed
Rank 4 completed

jglaser avatar Mar 24 '25 05:03 jglaser

On the Python side, I could confirm that PyTorch uses multiple threads to parallelize the backward pass:

# torch_reproducer.py
import torch
from torch.nn import Linear

model = Linear(4, 4).cuda()
input = torch.randn(1, 4, device='cuda', requires_grad=True)

def custom_hook(grad):
    import os, threading
    print(f"[Grad Hook] PID {os.getpid()} Thread {threading.get_native_id()}")
    return grad

output = model(input)
output.register_hook(custom_hook)
loss = output.sum()
loss.backward()

Output:

[Grad Hook] PID 3371415 Thread 3373242

Notice PID != Thread ID

jglaser avatar Mar 24 '25 05:03 jglaser

Now the mitigation strategy became obvious:

  1. Put all collectives into the main thread
  2. Make sure streams synchronize between calls to different communicators

Torch has an option to turn off threading during autograd:

from torch.autograd import set_multithreading_enabled

with set_multithreading_enabled(False):
    # trainer.train() or model.backward() ...

jglaser avatar Mar 24 '25 05:03 jglaser

To ensure that streams sync before and after collectives, a little more effort is needed: another context manager

import contextlib
import torch
import functools

import os
import torch.distributed as dist

# Get native thread ID
import ctypes
import threading

libc = ctypes.CDLL("libc.so.6")
gettid = libc.syscall
SYS_gettid = 186  # x86_64 Linux syscall number for gettid

@contextlib.contextmanager
def patch_distributed_collectives(label="DISTRIBUTED"):
    patched = {}
    dist_ops = ['all_reduce', 'reduce_scatter', 'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
                'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
                'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'all_reduce', 'reduce',
                'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier' ]

    def make_wrapper(opname, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            thread_id = threading.get_ident()
            native_tid = gettid(SYS_gettid)
            pid = os.getpid()
            try:
                rank = dist.get_rank()
            except RuntimeError:
                rank = -1
            #print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname}")
            stream = torch.cuda.current_stream()
            stream.synchronize()
            try:
                return fn(*args, **kwargs)
            finally:
                stream.synchronize()

        return wrapper

    try:
        for opname in dist_ops:
            if hasattr(dist, opname):
                original = getattr(dist, opname)
                patched[opname] = original
                setattr(dist, opname, make_wrapper(opname, original))
        yield
    finally:
        for opname, original in patched.items():
            setattr(dist, opname, original)


# Example usage:
# with patch_distributed_collectives():
#     output = model(input)

Using these two together fixes the hang for me. Performance may not be optimal and NCCL usage in torch autograd may have to be re-examined.

jglaser avatar Mar 24 '25 05:03 jglaser

To ensure that streams sync before and after collectives, a little more effort is needed: another context manager [...]

Using these two together fixes the hang for me. Performance may not be optimal and NCCL usage in torch autograd may have to be re-examined.

@thananon (as per offline discussion) try this patch ... if it still hangs, the next escalation strategy would be to replace the above two calls to stream.synchronize() with torch.cuda.synchronize(), to synchronize the entire device and not just the current stream
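
For reference, a minimal sketch of that escalation, assuming the wrapper keeps the same shape as make_wrapper() above and only swaps the per-stream sync for a device-wide one (the name make_device_sync_wrapper is mine, not part of the original patch):

import functools
import torch

def make_device_sync_wrapper(opname, fn):
    # Same structure as make_wrapper() above, but synchronizes the whole device
    # after the collective instead of only the current stream.
    # opname is kept only for drop-in compatibility with the patch above.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        finally:
            torch.cuda.synchronize()  # device-wide wait; heavier than stream.synchronize()
    return wrapper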

jglaser avatar Mar 25 '25 22:03 jglaser

@jglaser You found me here. Yes, I incorporated this patch into the latest job as you suggested. I hope we get a good result back.

thananon avatar Mar 25 '25 22:03 thananon

After fixing the first hang, my app was able to make it to an error message, which I was able to fix. Subsequently, however, it still sometimes hung.

This is a minimal version of the patch; it expands the coverage of distributed collectives and has only a single (necessary) synchronization point.

import contextlib
import torch
import functools
import threading

import os
import torch.distributed as dist

@contextlib.contextmanager
def patch_distributed_collectives(logging=False):
    patched = {}
    dist_ops = ['all_reduce', 'reduce_scatter', 'reduce_scatter_tensor',
                'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
                'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
                'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'reduce',
                'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier',
                '_broadcast_coalesced']

    def make_wrapper(label, opname, fn, logging=False):
        stream = torch.cuda.Stream()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if logging:
                thread_id = threading.get_ident()
                native_tid = threading.current_thread().native_id
                pid = os.getpid()

                try:
                    rank = dist.get_rank()
                except RuntimeError:
                    rank = -1
                device = torch.cuda.current_device()
                current_stream = torch.cuda.current_stream()
                async_op = kwargs.get('async_op', False)
                print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname} on device {device} stream {current_stream} async_op {async_op}")

            current_stream = torch.cuda.current_stream()
            try:
                return fn(*args, **kwargs)
            finally:
                current_stream.synchronize()

        return wrapper

    try:
        for opname in dist_ops:
            if hasattr(dist, opname):
                original = getattr(dist, opname)
                patched[opname] = original
                setattr(dist, opname, make_wrapper('DISTRIBUTED', opname, original, logging))
        yield
    finally:
        for opname, original in patched.items():
            setattr(dist, opname, original)

# Example usage:
# with patch_distributed_collectives():
#     output = model(input)

Using this version, and with autograd threads disabled as before, my app runs to completion on 64 nodes using rccl commit https://github.com/ROCm/rccl/commit/532f54c2444501b3655e65fbce6d00d4bfc19c0b (or actually, to the next OOM :)

The likely cause is that running concurrent collectives on different streams/communicators (e.g. intra- vs. inter-node) still requires stream dependencies to be set, which may (or may not) be implemented in FSDP. The above sync should become unnecessary once upstream NCCL 2.26 is merged into RCCL, so that users can set NCCL_LAUNCH_ORDER_IMPLICIT. Then they only need to ensure that the host-side order of calls is consistent, e.g. by disabling threads.
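
Assuming the upstream NCCL 2.26 semantics carry over unchanged once merged into RCCL, the setup could look roughly like the sketch below; whether and when RCCL honors NCCL_LAUNCH_ORDER_IMPLICIT is an assumption here, and the variable must be set before the communicators are created:

import os

# Assumption: the RCCL build in use has the NCCL 2.26 feature merged and honors
# this variable. Export it before the process group (and hence the RCCL
# communicators) is initialized.
os.environ["NCCL_LAUNCH_ORDER_IMPLICIT"] = "1"

import torch.distributed as dist

# Host-side call order must still be consistent across ranks, e.g. with autograd
# multithreading disabled as above. Assumes rank/world size/master address come
# from the launcher environment (e.g. torchrun or srun).
dist.init_process_group(backend="nccl")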

jglaser avatar Mar 29 '25 01:03 jglaser

https://github.com/pytorch/pytorch/issues/147729 seems related

jglaser avatar Apr 03 '25 02:04 jglaser

@jglaser can you try building https://github.com/pytorch/pytorch/pull/148590 to see if it resolves the deadlock? Or do you not expect it to resolve the issue?

jeffdaily avatar Apr 07 '25 23:04 jeffdaily

@jglaser can you try building pytorch/pytorch#148590 to see if it resolves the deadlock? Or do you not expect it to resolve the issue?

Will do -- and just FYI, I still encountered a hang with my RL app during collectives after switching to full BF16 training (instead of mixed)... will investigate @thananon

jglaser avatar Apr 08 '25 13:04 jglaser

@jglaser Just realized #148590 was relanded as https://github.com/pytorch/pytorch/commit/acf5139e57e5882e4526f784d35d4c52845f2fd4 due to a merge conflict. So it has been in main since Mon Mar 31. No need to build that PR yourself; it should be in the nightly wheels. Have you given a nightly wheel a try yet?

jeffdaily avatar Apr 08 '25 22:04 jeffdaily

I think I found what was causing the hang in my application: multiple (N/R)CCL communicators used by multiple CPU threads. [...]

This demonstration confirms that launching RCCL collectives on the same GPU stream from different host threads produces non-deterministic results.

I tried to run your code but I am getting

RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)
RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)

kswirydo avatar Apr 10 '25 01:04 kswirydo

I tried to run your code but I am getting

RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)
RCCL Error: invalid usage (run with NCCL_DEBUG=WARN for details)

This error may be expected with more than one thread. How many did you use (first command line argument)?

jglaser avatar Apr 21 '25 19:04 jglaser

I'm closing this PR as this conversation has moved to a meeting and has been resolved.

thananon avatar Nov 06 '25 14:11 thananon