
Enable opt-in autodetection of distributed configuration

Open · coreyjadams opened this issue 11 months ago · 8 comments

This PR is in response to the discussion on #19409.

It does the following:

  • First, it adds a new cluster environment for jax.distributed that uses mpi4py to autodetect the rank, size, local rank/devices, and coordinator address. The motivation is that on clusters with mpi4py, the parameters needed for jax.distributed.initialize can be inferred entirely from mpi4py, provided the job is launched in a way compatible with MPI.
  • Because this relies on the job being launched in an MPI-compatible way, the autodetection method is strictly opt-in: users must pass spec_detection_method="mpi4py" in a call to jax.distributed.initialize.
  • Consistent with the behavior of all other initialization methods, options passed explicitly to jax.distributed.initialize (coordinator_address, etc.) will override any settings auto-detected from mpi4py.
  • Because jax.distributed.initialize gains a new argument, I have updated the documentation accordingly.
  • Lastly, jax.distributed.initialize sometimes hangs on HPC systems (see #9582). I have added a warning that fires if any of the suspect variables are detected, and updated the documentation to note that users on an HPC cluster may want to unset these variables. I suspect the warning will be viewed as too noisy, since I don't know JAX's default log level offhand. If so, perhaps it could be emitted only when the timeout occurs, or at the very least kept in the documentation.
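To make the first bullet concrete, here is a minimal sketch of how MPI-provided values could map onto the keyword arguments of jax.distributed.initialize. It is not the code in this PR: the helper names, the default port, and the one-device-per-process assumption are all hypothetical. The pure mapping is kept separate from the mpi4py query so it can be exercised without an MPI launch.

```python
import socket
from typing import Dict, List, Union

# Hypothetical default port for the coordinator service; not taken from the PR.
DEFAULT_COORDINATOR_PORT = 12345


def settings_from_mpi_values(
    rank: int,
    size: int,
    local_rank: int,
    coordinator_host: str,
    port: int = DEFAULT_COORDINATOR_PORT,
) -> Dict[str, Union[str, int, List[int]]]:
    """Translate MPI-derived values into jax.distributed.initialize kwargs.

    Assumes one accelerator per process, so the node-local rank doubles as
    the local device id.
    """
    return {
        "coordinator_address": f"{coordinator_host}:{port}",
        "num_processes": size,
        "process_id": rank,
        "local_device_ids": [local_rank],
    }


def detect_settings_with_mpi4py(port: int = DEFAULT_COORDINATOR_PORT):
    """Query mpi4py for the values above (requires an MPI-launched job)."""
    from mpi4py import MPI  # imported lazily: only needed when opted in

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    # Processes on the same node land in the same shared-memory communicator,
    # so the rank within it is the node-local rank.
    local_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)
    local_rank = local_comm.Get_rank()
    local_comm.Free()
    # Every process adopts rank 0's hostname as the coordinator host.
    host = socket.gethostname() if rank == 0 else None
    coordinator_host = comm.bcast(host, root=0)
    return settings_from_mpi_values(
        rank, comm.Get_size(), local_rank, coordinator_host, port
    )
```

Under these assumptions, each process could call `jax.distributed.initialize(**detect_settings_with_mpi4py())`; with the PR applied, the opt-in form is `jax.distributed.initialize(spec_detection_method="mpi4py")`.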
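The override rule in the third bullet amounts to "explicit arguments win, detection fills the gaps". A minimal sketch, assuming (as is usual for optional Python arguments) that `None` means "not passed"; `resolve_settings` is a hypothetical helper, not JAX's implementation.

```python
from typing import Any, Dict


def resolve_settings(detected: Dict[str, Any], **explicit: Any) -> Dict[str, Any]:
    """Combine auto-detected settings with explicitly passed arguments.

    Arguments left as None are treated as "not provided" and fall back to
    the detected value; any other explicit value replaces the detected one.
    The input dict is not modified.
    """
    resolved = dict(detected)
    for key, value in explicit.items():
        if value is not None:
            resolved[key] = value
    return resolved


detected = {"coordinator_address": "node0:12345", "num_processes": 8, "process_id": 3}
# An explicit coordinator address overrides detection; process_id=None does not.
resolved = resolve_settings(detected, coordinator_address="head:9999", process_id=None)
```

Here `resolved` keeps the detected `num_processes` and `process_id` but uses `"head:9999"` as the coordinator address.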
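The warning in the last bullet could be as simple as scanning the environment before initialization. This sketch assumes the suspect variables are proxy settings (names ending in `_proxy`, such as `http_proxy` or `HTTPS_PROXY`); that is an assumption, since the bullet does not name them. Both function names are hypothetical.

```python
import logging
import os
from typing import List, Mapping, Optional

logger = logging.getLogger(__name__)


def find_proxy_variables(environ: Mapping[str, str]) -> List[str]:
    """Return the sorted names of variables that look like proxy settings."""
    # Assumption: the "suspect variables" are proxy settings; the PR text
    # does not list them here.
    return sorted(name for name in environ if name.lower().endswith("_proxy"))


def warn_about_proxy_variables(environ: Optional[Mapping[str, str]] = None) -> List[str]:
    """Log a warning naming any proxy variables found, and return them."""
    env = os.environ if environ is None else environ
    found = find_proxy_variables(env)
    if found:
        logger.warning(
            "Found proxy variable(s) %s; jax.distributed.initialize may hang "
            "on some clusters unless they are unset.",
            ", ".join(found),
        )
    return found
```

Because `logger.warning` is shown under Python's default logging configuration, the quieter options the bullet mentions would be demoting it to `logger.info` or calling it only after a timeout.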

I hope this is helpful, and it would simplify our lives using JAX on supercomputers :).

Corey

coreyjadams · Mar 11 '24 17:03

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] · Mar 11 '24 17:03