jax
Enable opt-in autodetection of distributed configuration
This PR is in response to the discussion on #19409.
It does the following:
- First, it adds a cluster environment for `jax.distributed` that autodetects the rank, size, local rank/devices, and coordinator address using `mpi4py`. The motivation is that on clusters with `mpi4py`, the parameters needed by `jax.distributed.initialize` can be inferred entirely from `mpi4py`, provided the job is launched in an MPI-compatible way.
- Because of that constraint (MPI compatibility), this autodetection method is strictly opt-in. Users must pass `spec_detection_method="mpi4py"` in a call to `jax.distributed.initialize`.
- Consistent with the behavior of all other initialization methods, options passed to `jax.distributed.initialize` for `coordinator_address`, etc., override any settings auto-detected from `mpi4py`.
- Because of the new argument to `jax.distributed.initialize`, I have updated the documentation accordingly.
- Lastly, related to using `jax.distributed.initialize` on HPC systems, there is sometimes a hang (see #9582). I have added a warning that is emitted if any of the suspect variables are detected, and updated the documentation to point out that users on an HPC cluster may want to unset these variables. I suspect the warning will be viewed as too noisy, since I don't know JAX's default log level offhand. If so, perhaps it could be emitted only when the timeout occurs, or at the very least the note could be kept in the documentation.
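The inference described above can be sketched in plain Python. This is a minimal sketch, not the PR's implementation: `DetectedArgs`, `infer_args`, `resolve`, and the coordinator port `12345` are hypothetical names and values chosen for illustration, and it assumes one accelerator per process. Given every rank's hostname (in an MPI job this could come from `MPI.COMM_WORLD.allgather(socket.gethostname())`), it places the coordinator on rank 0's host and computes the local rank as the number of lower-numbered ranks on the same host. Explicitly passed, non-`None` arguments take precedence over detected ones, mirroring the override behavior described in the list.

```python
from dataclasses import asdict, dataclass
from typing import List, Sequence

# Hypothetical coordinator port (an assumption, not taken from the PR).
DEFAULT_PORT = 12345


@dataclass
class DetectedArgs:
    coordinator_address: str
    num_processes: int
    process_id: int
    local_device_ids: List[int]


def infer_args(rank: int, hostnames: Sequence[str], port: int = DEFAULT_PORT) -> DetectedArgs:
    """Infer jax.distributed.initialize-style arguments for one process.

    hostnames[i] is the host that rank i runs on, e.g. the result of
    MPI.COMM_WORLD.allgather(socket.gethostname()).
    """
    size = len(hostnames)
    if not 0 <= rank < size:
        raise ValueError(f"rank {rank} is out of range for {size} processes")
    # Rank 0's host runs the coordination service.
    coordinator = f"{hostnames[0]}:{port}"
    # Local rank = number of lower-numbered ranks that share this rank's host.
    local_rank = sum(1 for h in hostnames[:rank] if h == hostnames[rank])
    # Assumption: one accelerator per process, selected by local rank.
    return DetectedArgs(coordinator, size, rank, [local_rank])


def resolve(detected: DetectedArgs, **explicit) -> dict:
    """Merge arguments: explicit non-None values override detected ones."""
    merged = asdict(detected)
    for key, value in explicit.items():
        if key not in merged:
            raise TypeError(f"unknown argument: {key}")
        if value is not None:
            merged[key] = value
    return merged


# Hypothetical wiring in a real MPI job (requires mpi4py and jax):
#   import socket
#   import jax
#   from mpi4py import MPI
#   comm = MPI.COMM_WORLD
#   hosts = comm.allgather(socket.gethostname())
#   jax.distributed.initialize(**resolve(infer_args(comm.Get_rank(), hosts)))
```

Grouping by hostname is only one way to obtain the local rank. MPI can also split the world communicator by shared memory (`COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)`) and use the rank within that node-local sub-communicator.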
I hope this is helpful; it would simplify our lives when using JAX on supercomputers :).
Corey
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.