jax.lax.pcollective_broadcast
This is only a draft PR and shouldn't be seriously reviewed until https://github.com/openxla/xla/pull/8968 is merged into OpenXLA.
Add a new jax.lax op, pcollective_broadcast, that calls the new CollectiveBroadcast StableHLO op.
This new operation is intended to lower directly to ncclBroadcast (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclbroadcast).
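For intuition, here is a minimal sketch of the intended semantics, assuming the (operand, axis_name, source_index) argument order used in the example below; reference_pbroadcast is a hypothetical helper built from existing collectives, not the implementation in this PR, which is expected to lower to a single CollectiveBroadcast/ncclBroadcast call:

```python
import jax

def reference_pbroadcast(x, axis_name, source):
    # Hypothetical reference semantics, for illustration only: every device
    # along `axis_name` ends up with the shard held by device `source`.
    # Must be called inside shard_map/pmap over `axis_name`.
    gathered = jax.lax.all_gather(x, axis_name)  # stack shards: (axis_size, *x.shape)
    return gathered[source]                      # keep only the source's shard
```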
Example usage of this operation:
@jax.jit
@partial(shard_map, mesh=...,
         in_specs=P('i', None), out_specs=P('i', None), check_rep=False)
def f(a):
  return jax.lax.pcollective_broadcast(a, 'i', 3)

x = jnp.arange(64).reshape((8, 8))
print(x)
print(f(x))
This prints
Array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62, 63]], dtype=int32)

Array([[24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [24, 25, 26, 27, 28, 29, 30, 31]], dtype=int32)
as the row x[3] == [24, 25, 26, 27, 28, 29, 30, 31] is broadcast to all devices.
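Equivalently, assuming the f and x defined in the example above, the result should match a plain broadcast of row 3:

```python
import numpy as np
import jax.numpy as jnp

expected = jnp.broadcast_to(x[3], x.shape)     # row 3 repeated on every row
np.testing.assert_array_equal(f(x), expected)  # f and x from the example above
```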
@mattjj If you want to start looking at it, go for it. All of the meat is there.
I still need to add tests, which requires backend support, but beyond that this is ready for full review.
I added a test.
Yes, I will add much more test coverage.
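For reference, a minimal sketch of the kind of multi-device test this could get; the mesh setup, import paths, and assertions are my assumptions, and it presumes at least four devices along the 'i' axis so that source index 3 exists:

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def test_pcollective_broadcast_broadcasts_source_shard():
    # Assumes at least 4 devices so that source index 3 exists along axis 'i'.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, ('i',))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=P('i', None),
             out_specs=P('i', None), check_rep=False)
    def f(a):
        return jax.lax.pcollective_broadcast(a, 'i', 3)  # op under test

    n = devices.size
    x = jnp.arange(n * 8).reshape((n, 8))
    expected = jnp.broadcast_to(x[3], x.shape)  # source shard on every device
    np.testing.assert_array_equal(f(x), expected)
```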
Do we also wish to discuss the name pcopy? Now is the time to do so. Pinging @nouiz, who had opinions to share.
If you prefer to use the name pbroadcast, I'd be open to renaming our other pbroadcast! Should be easy since I expect there are ~no users of the existing pbroadcast.
Reasons to call the thing in this PR pbroadcast:
- it corresponds to the HLO name
- it corresponds to the traditional MPI name
Reasons not to:
- nothing, nada, zilch
WDYT?
Good enough reasoning for me.
@mattjj should be g2g now
I removed the test in shard_map_test.py.
I think the next step is for Google to review/approve. It would be great to have it merged by Monday or Tuesday, as we will present it at GTC on Wednesday.
I'll make sure this gets into at least main by Wednesday!
There were a bunch of tiny errors on my end. @mattjj I tested this on an 8-GPU node and the tests pass, so we should be g2g now.