[XLA:GPU] Add participating groups to NCCL clique key to fix split hang
When using `--xla_gpu_enable_nccl_comm_splitting=true`, a deadlock can occur if one or more subgroups of a split were already created and those devices reuse them from the clique map, while the remaining subgroups initiate a split and wait forever for the rest of the devices to join. This was occurring in the JAX pmap_test, as reported by @hawkinsp.
The code below reproduces the issue by first creating a communicator with all devices [0, 1, 2, 3]. Next, we create a communicator for [0, 1]. Then, we try to split [0, 1, 2, 3] -> [0, 1] and [2, 3]. On ranks 0 and 1, XLA reuses the [0, 1] communicator that was created earlier. However, ranks 2 and 3 begin an NcclCommSplit, and they are stuck forever because ranks 0 and 1 never join the split.
To fix this, the clique map key now also includes the full set of groups across all devices. This ensures that clique map lookups behave consistently across all ranks; in this situation it prevents ranks 0 and 1 from reusing the earlier [0, 1] communicator.
```python
import jax
import jax.numpy as jnp
from jax import lax

def create_comm_0_1():
    x = jnp.arange(2 * 2).reshape(2, 2)
    ans = jax.pmap(lambda x: jax.lax.psum(jax.lax.psum(x, 'i'), 'i'),
                   in_axes=0, out_axes=None, axis_name='i')(x)
    print(ans)

def create_comms_0_1_and_2_3():
    x = jnp.arange(4 * 2).reshape(4, 2)
    ans = jax.pmap(lambda x: jax.lax.psum(jax.lax.psum(x, 'i'), 'i',
                                          axis_index_groups=[[0, 1], [2, 3]]),
                   in_axes=0, out_axes=None, axis_name='i')(x)
    print(ans)

create_comm_0_1()
create_comms_0_1_and_2_3()  # <--- Hangs without this PR!
```
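For reviewers, here is a simplified model of the lookup behavior, not the actual XLA code (all names and key shapes below are hypothetical): with device-group-only keys, ranks 0/1 get a cache hit while ranks 2/3 miss, so their decisions diverge; once the key also carries the full set of groups, all four ranks miss and all of them join the split together.

```python
def lookup_old(clique_map, devices):
    # Old behavior: the key is only this rank's group of devices.
    return clique_map.get(tuple(devices))

def lookup_new(clique_map, devices, all_groups):
    # Fixed behavior: the key also includes the full set of groups across
    # all devices, so every rank sees the same hit/miss result.
    return clique_map.get((tuple(devices), tuple(map(tuple, all_groups))))

# A communicator for [0, 1] already exists from an earlier collective.
clique_map_old = {(0, 1): "comm_0_1"}
clique_map_new = {((0, 1), ((0, 1),)): "comm_0_1"}

groups = [[0, 1], [2, 3]]

# Old keys: ranks 0/1 hit the cache and reuse, ranks 2/3 miss and start a
# split -- inconsistent decisions, so ranks 2/3 wait forever.
hit_01 = lookup_old(clique_map_old, [0, 1]) is not None  # True -> reuse
hit_23 = lookup_old(clique_map_old, [2, 3]) is not None  # False -> split

# New keys: the cached entry was created under groups ((0, 1),), which does
# not match ((0, 1), (2, 3)), so every rank misses and all join the split.
hit_01_new = lookup_new(clique_map_new, [0, 1], groups) is not None  # False
hit_23_new = lookup_new(clique_map_new, [2, 3], groups) is not None  # False
```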
I've also refactored slightly so that GetNcclCliqueKey is always used when creating the NCCL clique key, to avoid some duplicated code.
@ezhulenev Could you please take a look when you have a chance?
@ezhulenev a reminder to review this PR.