jax.lax.pcollective_broadcast

Open · chaserileyroberts opened this pull request 1 year ago

This is only a draft PR, and shouldn't be seriously reviewed until after https://github.com/openxla/xla/pull/8968 is merged into OpenXLA.

Add a new jax.lax op, pcollective_broadcast, that lowers to the new CollectiveBroadcast StableHLO op.

This new operation is intended to lower directly to ncclBroadcast (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclbroadcast).
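For comparison, here is a minimal sketch of how the same result can be obtained with an existing collective. It assumes only jax.lax.all_gather and a named jax.vmap axis standing in for devices; broadcast_via_all_gather is a hypothetical helper, not part of this PR. Every member receives all N values and then discards N-1 of them. A direct ncclBroadcast avoids that extra traffic because only the source's buffer has to be sent.

```python
import jax
import jax.numpy as jnp


def broadcast_via_all_gather(a, axis_name, source):
  # Hypothetical helper (not the PR's op): emulate a collective broadcast.
  # Every member of `axis_name` receives every member's value...
  gathered = jax.lax.all_gather(a, axis_name)
  # ...and keeps only the one that came from `source`.
  return gathered[source]


x = jnp.arange(64).reshape((8, 8))
# vmap with axis_name='i' maps over the 8 rows as if each row lived on its own device.
out = jax.vmap(lambda row: broadcast_via_all_gather(row, 'i', 3), axis_name='i')(x)
print(out[0])  # [24 25 26 27 28 29 30 31]
```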

Example usage of this operation:

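from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumes at least 8 devices, arranged along a single mesh axis named 'i'.
mesh = Mesh(jax.devices()[:8], ('i',))

@jax.jit
@partial(shard_map, mesh=mesh,
    in_specs=P('i', None), out_specs=P('i', None), check_rep=False)
def f(a):
  return jax.lax.pcollective_broadcast(a, 'i', 3)  # source: index 3 on axis 'i'

x = jnp.arange(64).reshape((8, 8))
print(repr(x))
print(repr(f(x)))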

This prints

Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62, 63]], dtype=int32)
Array([[24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31]], dtype=int32)

because, with in_specs=P('i', None) over 8 devices, device 3 along axis 'i' holds the row x[3] == [24, 25, 26, 27, 28, 29, 30, 31], and that row is broadcast to all devices.
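The expected output can be derived without any devices. Below is a minimal reference model in plain jax.numpy. It assumes that in_specs=P('i', None) splits x into equal row blocks, one per device along 'i', and that out_specs=P('i', None) concatenates them back. reference_broadcast is a hypothetical helper, not JAX API.

```python
import jax.numpy as jnp


def reference_broadcast(x, num_devices, source):
  # Hypothetical reference model: no devices or collectives involved.
  # in_specs=P('i', None): x is split into `num_devices` equal row blocks,
  # and block k belongs to device k along mesh axis 'i'.
  shards = jnp.split(x, num_devices, axis=0)
  # After the broadcast, every device holds the source device's block...
  per_device = [shards[source]] * num_devices
  # ...and out_specs=P('i', None) concatenates the blocks back along rows.
  return jnp.concatenate(per_device, axis=0)


x = jnp.arange(64).reshape((8, 8))
expected = reference_broadcast(x, num_devices=8, source=3)
print(expected[0])  # [24 25 26 27 28 29 30 31]
```

With 4 devices instead, each block would be two rows, and source=3 would select x[6] and x[7]. The printed output therefore implies one row per device.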

chaserileyroberts, Jan 31 '24

@mattjj If you want to start looking at it, go for it. All of the meat is there.

chaserileyroberts, Feb 12 '24

I still need to add tests, which require backend support, but beyond that this is ready for full review.
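One possible backend-independent oracle for such tests is sketched below. It assumes a named jax.vmap axis is an acceptable stand-in for devices: zero out every member except the source, then psum. broadcast_via_psum and check_against_reference are hypothetical names, not the test added in this PR.

```python
import jax
import jax.numpy as jnp
import numpy as np


def broadcast_via_psum(a, axis_name, source):
  # Hypothetical oracle: non-source members contribute zeros, so the sum
  # over `axis_name` equals the source member's value on every member.
  is_source = jax.lax.axis_index(axis_name) == source
  return jax.lax.psum(jnp.where(is_source, a, jnp.zeros_like(a)), axis_name)


def check_against_reference(num_members=8, source=3):
  # Hypothetical test body: a named vmap axis stands in for real devices,
  # so this runs on any backend, including a single CPU.
  x = jnp.arange(num_members * 8).reshape((num_members, 8))
  out = jax.vmap(lambda row: broadcast_via_psum(row, 'i', source),
                 axis_name='i')(x)
  expected = np.broadcast_to(np.asarray(x[source]), x.shape)
  np.testing.assert_array_equal(np.asarray(out), expected)
  return out


out = check_against_reference()
```

The masking trick assumes a numeric dtype and pays for a full all-reduce. By contrast, the op in this PR is meant to lower to a single ncclBroadcast.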

chaserileyroberts, Feb 21 '24

I added a test.

chaserileyroberts, Mar 06 '24

Yes, I will add much more test coverage. Do we wish to also discuss the name pcopy? Now is the time to do so. Pinging @nouiz who had opinions to share.

chaserileyroberts, Mar 06 '24

If you prefer to use the name pbroadcast, I'd be open to renaming our other pbroadcast! It should be easy, since I expect there are ~no users of the existing pbroadcast.

Reasons to call the thing in this PR pbroadcast:

  1. it corresponds to the HLO name
  2. it corresponds to the traditional MPI name

Reasons not to:

  3. nothing, nada, zilch

WDYT?

mattjj, Mar 06 '24

Good enough reasoning for me.

chaserileyroberts, Mar 07 '24

@mattjj should be g2g now

chaserileyroberts, Mar 11 '24

I removed the test in shard_map_test.py.

chaserileyroberts, Mar 12 '24

I think the next step is for Google to review/approve. It would be great to have it merged by Monday or Tuesday, as we will present it at GTC on Wednesday.

nouiz, Mar 16 '24

I'll make sure this gets into at least main by Wednesday!

mattjj, Mar 18 '24

There were a bunch of stupid tiny errors on my end. @mattjj I tested this on an 8-GPU node, and the tests pass, so we should be g2g now.

chaserileyroberts, Mar 18 '24