
[XLA:GPU] Add participating groups to NCCL clique key to fix split hang

Open · trevor-m opened this issue 1 year ago · 1 comment

When using --xla_gpu_enable_nccl_comm_splitting=true, a deadlock can occur if the communicators for one or more subgroups of a split were already created. The devices in those subgroups reuse the existing communicators from the clique map, while the devices in the other subgroups initiate a split and wait forever for the rest of the devices to join. This was occurring in the JAX pmap_test, as reported by @hawkinsp.

The code below reproduces the issue. It first creates a communicator for devices [0, 1], then one for all devices [0, 1, 2, 3], and then tries to split [0, 1, 2, 3] into [0, 1] and [2, 3]. On ranks 0 and 1, XLA reuses the [0, 1] communicator that was created earlier. However, ranks 2 and 3 begin an NcclCommSplit. They are stuck forever, since ranks 0 and 1 never join the split.

To fix this, the clique map key now also includes the full set of groups across all devices. This makes clique map lookups behave consistently across all ranks. In this situation, it prevents ranks 0 and 1 from reusing the earlier [0, 1] communicator, so all four ranks join the split.

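import os
# Enable communicator splitting before JAX initializes its GPU backend.
# The reproducer assumes a machine with at least 4 GPUs.
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "")
                           + " --xla_gpu_enable_nccl_comm_splitting=true")

import jax
import jax.numpy as jnp
from jax import lax

def create_comm_0_1():
    # pmap over 2 devices: the all-reduces create a clique for devices [0, 1].
    x = jnp.arange(2 * 2).reshape(2, 2)
    ans = jax.pmap(lambda x: lax.psum(lax.psum(x, 'i'), 'i'),
                   in_axes=0, out_axes=None, axis_name='i')(x)
    print(ans)

def create_comms_0_1_and_2_3():
    # pmap over 4 devices: the inner psum uses a clique for [0, 1, 2, 3]; the
    # outer psum with axis_index_groups splits it into [0, 1] and [2, 3].
    x = jnp.arange(4 * 2).reshape(4, 2)
    ans = jax.pmap(lambda x: lax.psum(lax.psum(x, 'i'), 'i',
                                      axis_index_groups=[[0, 1], [2, 3]]),
                   in_axes=0, out_axes=None, axis_name='i')(x)
    print(ans)

create_comm_0_1()
create_comms_0_1_and_2_3()  # <--- Hangs without this PR!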
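The lookup divergence in the reproducer above can be modeled with a small, self-contained Python toy. This is a sketch under stated assumptions, not XLA's C++ implementation: the clique map is a plain dict, and the old key is assumed to be just the sorted devices of a rank's own group. `old_key` and `actions` are hypothetical names introduced only for illustration.

```python
# Toy model (not XLA's code) of the old clique-map lookup: the key holds only
# the devices of the rank's own group, so different ranks can disagree.

def old_key(rank, groups):
    """Hypothetical old-style key: the sorted devices of this rank's group."""
    return next(tuple(sorted(g)) for g in groups if rank in g)

def actions(groups, clique_map):
    """What each participating rank does: reuse a cached clique or split."""
    ranks = sorted(r for g in groups for r in g)
    return {r: "reuse" if old_key(r, groups) in clique_map else "split" for r in ranks}

# Cliques cached after create_comm_0_1() and the inner psum over 4 devices.
clique_map = {(0, 1): "comm[0,1]", (0, 1, 2, 3): "comm[0,1,2,3]"}

acts = actions([[0, 1], [2, 3]], clique_map)
print(acts)  # {0: 'reuse', 1: 'reuse', 2: 'split', 3: 'split'}

# Splitting the parent [0, 1, 2, 3] communicator is collective: all 4 ranks
# must call it. If only some ranks do, those ranks block forever.
splitters = [r for r, a in acts.items() if a == "split"]
deadlock = 0 < len(splitters) < 4
print("deadlock:", deadlock)  # deadlock: True
```

In this model, any mix of "reuse" and "split" across the ranks of one collective means the splitting ranks wait on peers that never arrive.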

I've also done some light refactoring so that GetNcclCliqueKey is always used to create the NCCL clique key, which removes some duplicated code.
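A toy model can illustrate the key change that fixes the hang. Again, this is a hypothetical sketch, not the actual GetNcclCliqueKey: `new_key` (an illustrative name) pairs a rank's own group with the sorted tuple of all participating groups. Under this assumption, a [0, 1] communicator created by a 2-device pmap no longer matches the [0, 1] subgroup of a [[0, 1], [2, 3]] split.

```python
# Toy model (not XLA's code) of a clique key that also records every
# participating group, so all ranks of one collective compute matching keys.

def new_key(rank, groups):
    """Hypothetical key: (this rank's group, all participating groups)."""
    own = next(tuple(sorted(g)) for g in groups if rank in g)
    everyone = tuple(sorted(tuple(sorted(g)) for g in groups))
    return (own, everyone)

def actions(groups, clique_map):
    """What each participating rank does: reuse a cached clique or split."""
    ranks = sorted(r for g in groups for r in g)
    return {r: "reuse" if new_key(r, groups) in clique_map else "split" for r in ranks}

# Cliques cached by the reproducer before its grouped psum runs.
clique_map = {
    new_key(0, [[0, 1]]): "comm[0,1]",
    new_key(0, [[0, 1, 2, 3]]): "comm[0,1,2,3]",
}

split_groups = [[0, 1], [2, 3]]
before = actions(split_groups, clique_map)
print(before)  # {0: 'split', 1: 'split', 2: 'split', 3: 'split'}

# After the split completes, cache one communicator per subgroup...
for g in split_groups:
    clique_map[new_key(g[0], split_groups)] = f"comm{g}"

# ...so the next grouped psum is a cache hit on every rank.
after = actions(split_groups, clique_map)
print(after)  # {0: 'reuse', 1: 'reuse', 2: 'reuse', 3: 'reuse'}
```

Because every rank derives the same `everyone` component, the ranks of one collective either all hit or all miss in the clique map. The trade-off in this model is that the same device set used under different groupings gets its own cached communicator.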

trevor-m · Aug 09 '24 22:08

@ezhulenev Could you please take a look when you have a chance?

trevor-m · Aug 09 '24 22:08

@ezhulenev, a reminder to review this PR.

sgerrard · Aug 12 '24 22:08