ray
[Collective] Support for Jax Array Collective
Description
Make the existing CuPy-based collectives work out of the box with JAX arrays, for example through a wrapper around CuPy as in Alpa.
Use case
GPU-to-GPU collective ops for JAX arrays on a Ray cluster:
- as seen in Alpa: https://github.com/alpa-projects/alpa/blob/87850dad9bce97229eace74a71d76c8cc412df47/alpa/collective/worker_nccl_util_cupy.py#L225-L259
- other JAX codebases that use Ray and GPUs (quite a lot these days :) )
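A minimal sketch of the wrapper approach from the description: import incoming arrays through DLPack, hand them to an existing collective, and convert the results back. It assumes NumPy 1.22 or newer on CPU, with NumPy standing in for CuPy. `fake_allreduce` and `wrap_with_dlpack` are hypothetical names, and `fake_allreduce` only simulates a sum allreduce over a Python list of per-rank buffers rather than calling NCCL or `ray.util.collective`. Any object implementing `__dlpack__` (for instance a CPU JAX array in recent JAX versions) could be passed in place of the NumPy inputs.

```python
from typing import Callable, List, Sequence

import numpy as np


# Hypothetical stand-in for a CuPy/NCCL allreduce: every "rank" receives the
# element-wise sum of all ranks' buffers. A real collective would run across
# processes/GPUs instead of over a Python list.
def fake_allreduce(buffers: List[np.ndarray]) -> List[np.ndarray]:
    total = np.sum(np.stack(buffers), axis=0)
    return [total.copy() for _ in buffers]


def wrap_with_dlpack(
    collective: Callable[[List[np.ndarray]], List[np.ndarray]],
    to_foreign: Callable[[np.ndarray], object] = lambda buf: buf,
) -> Callable[[Sequence[object]], List[object]]:
    """Let a collective written for NumPy (standing in for CuPy) accept any
    array that implements the DLPack protocol."""

    def wrapped(arrays: Sequence[object]) -> List[object]:
        # Zero-copy import of each incoming array through DLPack.
        buffers = [np.from_dlpack(a) for a in arrays]
        results = collective(buffers)
        # Convert results back to the caller's array type (identity here;
        # for JAX this could be jax.dlpack.from_dlpack or jnp.asarray).
        return [to_foreign(r) for r in results]

    return wrapped


allreduce = wrap_with_dlpack(fake_allreduce)
ranks = [np.arange(3, dtype=np.float32) * (r + 1) for r in range(2)]
out = allreduce(ranks)
print(out[0])  # [0. 3. 6.]
```

The wrapper touches arrays only through DLPack, so in principle the same shape carries over to CuPy on GPU, where the import also avoids a copy.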
Hi @OrenLeung, I actually just chatted with the Alpa team about this exact topic yesterday. The latest status is that I'm pushing to upstream Ray collective for JAX into the Ray repo, to make Ray the best framework for distributed JAX on a cluster of GPUs. Both Alpa and Ray/Anyscale are aligned on this.
If you have particular workloads in mind or want to share your understanding of this community, I'm more than happy to learn more!
Hi @jiaodong, I am currently working on a project that requires collective communication (CC) on JAX. For now I am using mpi4jax, since Ray does not yet support CC on JAX arrays. I am interested in contributing to this feature. Could you please share some information about its progress, and whether it's possible for me to contribute? Thanks.
Hi @ntlm1686, thanks for your interest and willingness to contribute! The most active user of Ray collective is currently Alpa, which has both CuPy-based and XLA-based implementations ready that we use on a regular basis. We've discussed and plan to upstream collective to the Ray repo in the longer run, but we're laser-focused on another priority right now (which you will hear about from us soon).
If you want to get started ASAP, I would go with https://github.com/alpa-projects/alpa/blob/main/alpa/collective/worker_nccl_util_cupy.py , which essentially builds collectives on a few primitives:
NOTE: CuPy-based collectives don't work for bf16 tensors on the Ampere architecture due to an incompatibility with DLPack.
```python
import cupy
# Import paths for the older jax/jaxlib releases Alpa targeted at the time.
# The xc._xla DLPack helpers are private APIs and may be absent in newer versions.
from jax.dlpack import from_dlpack, to_dlpack
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc


def to_signal_buffer(jax_tensor):
    """Convert a JAX array to a CuPy signal buffer.

    take_ownership=True hands the underlying buffer to CuPy, so the JAX
    array should not be used afterwards."""
    return jax_tensor_to_cupy(jax_tensor, take_ownership=True)


def xla_buffer_to_cupy(xla_buf, take_ownership=False):
    """Convert an xla buffer directly to cupy, w/o transitioning from jax
    buffer."""
    return cupy.fromDlpack(
        xc._xla.buffer_to_dlpack_managed_tensor(  # pylint: disable=protected-access
            xla_buf,
            take_ownership=take_ownership))


def cupy_to_xla_buffer(tensor):
    """Convert cupy tensors to XLA buffers."""
    if isinstance(tensor, list):
        return list(map(cupy_to_xla_buffer, tensor))
    cpu_backend = xb.get_backend("cpu")
    try:
        gpu_backend = xb.get_backend("gpu")
    except RuntimeError:
        # No GPU backend available (e.g. CPU-only jaxlib).
        gpu_backend = None
    buf = xc._xla.dlpack_managed_tensor_to_buffer(  # pylint: disable=protected-access
        tensor.toDlpack(), cpu_backend, gpu_backend)
    return buf


def jax_tensor_to_cupy(tensors, take_ownership=False):
    """Convert a Jax DeviceArray to cupy tensor; zero copy."""
    if isinstance(tensors, list):
        # Propagate take_ownership to every element of the list.
        return [jax_tensor_to_cupy(t, take_ownership) for t in tensors]
    return cupy.fromDlpack(to_dlpack(tensors, take_ownership=take_ownership))


def cupy_to_jax_tensor(tensors):
    """Convert cupy tensors to JAX tensors."""
    if isinstance(tensors, list):
        return list(map(cupy_to_jax_tensor, tensors))
    return from_dlpack(tensors.toDlpack())
```
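One possible workaround for the bf16 limitation in the NOTE, sketched under assumptions rather than taken from Alpa: upcast tensors whose dtype is assumed not to survive the DLPack hop to float32 before conversion, run the collective, and cast the results back. NumPy again stands in for CuPy; `UNSUPPORTED_DLPACK_DTYPES`, `prepare_for_transport`, `guarded_collective` and `sum_allreduce` are hypothetical names. Because plain NumPy has no bfloat16, the usage example marks float16 as unsupported instead.

```python
from typing import Callable, Iterable, List, Tuple

import numpy as np

# Dtypes assumed not to survive the DLPack hop (per the bf16 note above).
UNSUPPORTED_DLPACK_DTYPES = frozenset({"bfloat16"})


def prepare_for_transport(
    x: np.ndarray,
    unsupported: Iterable[str] = UNSUPPORTED_DLPACK_DTYPES,
    fallback: type = np.float32,
) -> Tuple[np.ndarray, np.dtype]:
    """Upcast an unsupported dtype before the DLPack conversion and remember
    the original dtype so it can be restored afterwards."""
    if str(x.dtype) in set(unsupported):
        return x.astype(fallback), x.dtype
    return x, x.dtype


def guarded_collective(
    arrays: List[np.ndarray],
    collective: Callable[[List[np.ndarray]], List[np.ndarray]],
    unsupported: Iterable[str] = UNSUPPORTED_DLPACK_DTYPES,
) -> List[np.ndarray]:
    prepared = [prepare_for_transport(a, unsupported) for a in arrays]
    # np.from_dlpack stands in for jax_tensor_to_cupy on the way in.
    buffers = [np.from_dlpack(buf) for buf, _ in prepared]
    results = collective(buffers)
    # Cast each result back to the dtype the caller passed in.
    return [r.astype(dtype) for r, (_, dtype) in zip(results, prepared)]


def sum_allreduce(buffers: List[np.ndarray]) -> List[np.ndarray]:
    # In-process simulation: every rank gets the element-wise sum.
    total = np.sum(np.stack(buffers), axis=0)
    return [total.copy() for _ in buffers]


# float16 plays the role of bfloat16 here.
halves = [np.ones(4, dtype=np.float16), np.full(4, 2, dtype=np.float16)]
out = guarded_collective(halves, sum_allreduce, unsupported={"float16"})
print(out[0])  # [3. 3. 3. 3.]
```

The trade-off: a bf16 tensor sent as float32 costs twice the bytes, and `astype` makes a copy, so the zero-copy path is lost for exactly those tensors.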
I would wait until things are sorted out regarding the governance model and stable CI/CD for this component moving forward.
Hi, I'm a bot from the Ray team :)
To help human contributors to focus on more relevant issues, I will automatically add the stale label to issues that have had no activity for more than 4 months.
If there is no further activity in the next 14 days, the issue will be closed!
- If you'd like to keep the issue open, just leave any comment, and the stale label will be removed!
- If you'd like to get more attention to the issue, please tag one of Ray's contributors.
You can always ask for help on our discussion forum or Ray's public slack channel.
Hi again! The issue will be closed because there has been no more activity in the 14 days since the last message.
Please feel free to reopen or open a new issue if you'd still like it to be addressed.
Again, you can always ask for help on our discussion forum or Ray's public slack channel.
Thanks again for opening the issue!