
Make automatic distributed initialization work with Open MPI 5.x, PMIx, and PRRTE

EricHallahan opened this issue on Feb 18 '23 · 8 comments

Background

https://github.com/google/jax/pull/13929 introduced automatic JAX distributed initialization via the Open MPI Open Run-Time Environment (ORTE) layer and its orterun process launcher (also known by its aliases mpirun, mpiexec, oshrun, and shmemrun).

Upcoming Open MPI 5.x series releases do away with the previous ORTE infrastructure in favor of one based around the PMIx standard, via the OpenPMIx reference PMIx implementation and the complementary PMIx Reference Run-Time Environment (PRRTE); in Open MPI 5.x the mpirun/mpiexec launcher is simply a wrapper for the PRRTE prterun launcher.

PMIx and PRRTE behave differently from ORTE, which makes the implementation introduced in https://github.com/google/jax/pull/13929 incompatible with Open MPI 5.x. With Open MPI 5.0 (now at its tenth release candidate) approaching release, there seems to be value in preparing JAX for this change.

Considerations & Challenges

Continued compatibility with ORTE and orterun

The current implementation (as introduced in https://github.com/google/jax/pull/13929) is fully usable with Open MPI versions prior to 5.x, and it is important to maintain compatibility with these releases when introducing support for Open MPI 5.x. It is unclear to me whether it would be wiser to make the current implementation compatible with the PRRTE-based launcher, or to create a separate piece of code to handle it.

New behaviors

PMIx/PRRTE exposes relevant information differently than ORTE.

OMPI_VERSION=5.0.0rc10
OMPI_TOOL_NAME=mpirun
PRTE_LAUNCHED=1
PMIX_NAMESPACE=prterun-%{hostname}-%{num_job_id}@1
PMIX_RANK=0
PMIX_SERVER_URI41=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI4=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI3=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI2=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI21=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_TMPDIR=/tmp/prte.%{hostname}.%{uid}/dvm.%{num_job_id}
PMIX_SYSTEM_TMPDIR=/tmp
OMPI_COMM_WORLD_SIZE=1
OMPI_WORLD_SIZE=1
OMPI_MCA_num_procs=1
OMPI_COMM_WORLD_RANK=0
OMPI_COMM_WORLD_LOCAL_RANK=0
OMPI_COMM_WORLD_NODE_RANK=0
PMIX_HOSTNAME=%{hostname}
A selection of variables exposed to a process when launching with mpirun under Open MPI 5.x.
  • OMPI_MCA_orte_hnp_uri no longer exists; the server URI is instead exposed to the process via a family of PMIX_SERVER_URI environment variables (one for each supported version of the PMIx standard). This means the current implementation is not activated at all by the PRRTE process launcher. Even if it were, the values of these variables differ from OMPI_MCA_orte_hnp_uri and require meaningfully different handling: the identifier preceding the address is no longer exclusively numeric but is instead the base of the job namespace (exposed via PMIX_NAMESPACE), derived from the server tmpdir (exposed via PMIX_SERVER_TMPDIR), which itself is derived from the server hostname (exposed via PMIX_HOSTNAME) and the numeric job identifier.

  • Open MPI environment variables exposing the world size, world rank, and local rank are unchanged, but PMIx also exposes the world rank itself via PMIX_RANK.

  • Detecting whether the process was launched with prterun is more convenient than with orterun: PRTE_LAUNCHED is set to 1. (A sketch combining these observations follows this list.)
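
To make the differences concrete, here is a minimal sketch, not the current JAX implementation, of how the environment shown above could be interpreted. The helper names are made up for illustration, the choice of PMIX_SERVER_URI21 is arbitrary (every PMIX_SERVER_URI* variant in the dump carries the same address), and only the tcp4 URI form shown above is handled.

import os
import re


def prrte_env_present() -> bool:
    # prterun (and the Open MPI 5.x mpirun wrapper) sets PRTE_LAUNCHED=1.
    return os.environ.get('PRTE_LAUNCHED') == '1'


def parse_pmix_server_uri(uri: str) -> tuple[int, str]:
    # Example value (from the dump above):
    #   prterun-<hostname>-<jobid>@0.0;tcp4://<ip>:<port>
    namespace, address = uri.split(';', maxsplit=1)
    # The numeric job id is embedded in the namespace rather than leading
    # the URI as it does for OMPI_MCA_orte_hnp_uri.
    job_id = int(re.search(r'-(\d+)@', namespace).group(1))
    # Keep only the host part; take the first entry if several are listed.
    ip = address.split('://', maxsplit=1)[1].rsplit(':', maxsplit=1)[0].split(',')[0]
    return job_id, ip


if __name__ == '__main__':
    if prrte_env_present():
        job_id, ip = parse_pmix_server_uri(os.environ['PMIX_SERVER_URI21'])
        rank = int(os.environ['PMIX_RANK'])  # matches OMPI_COMM_WORLD_RANK
        size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        print(f'job {job_id} at {ip}: process {rank} of {size}')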

EricHallahan · Feb 18 '23 21:02

@EricHallahan are you interested in implementing the support? cc: @nvcastet

mjsML · Feb 21 '23 20:02

@EricHallahan are you interested in implementing the support?

Prior to filing this issue, I made a patch to the existing implementation to make it work with Open MPI 5.x. I am willing to contribute it, but the question remains how to maintain support for earlier versions (something I didn't consider for my personal use), as they are going to remain in use for many years to come.

EricHallahan · Feb 25 '23 22:02

@EricHallahan Thanks a lot for raising this issue and the thorough discussion! We detect current OMPI with ORTE via the presence of the env var OMPI_MCA_orte_hnp_uri (see https://github.com/google/jax/blob/main/jax/_src/clusters/ompi_cluster.py#L28-L29). For Open MPI with PRRTE, you could contribute your patch as a new subclass of ClusterEnv; the OMPI ORTE detection will not interfere, since OMPI_MCA_orte_hnp_uri will not be defined with PRRTE. Does that make sense?
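
For illustration, here is a hedged sketch of what such a subclass might look like, modeled on the classmethod interface of the existing ORTE-based detector (is_env_present, get_coordinator_address, get_process_count, get_process_id, get_local_process_id). The class name, the choice of PMIX_SERVER_URI21, and the port-derivation scheme are assumptions rather than a definitive implementation, and the exact ClusterEnv interface and registration mechanism should be checked against the jax._src.clusters version in use.

import os
import re

from jax._src import clusters


class PrrteCluster(clusters.ClusterEnv):  # hypothetical name

    @classmethod
    def is_env_present(cls) -> bool:
        # Keys off PRTE_LAUNCHED, so it cannot collide with the ORTE detector,
        # which keys off OMPI_MCA_orte_hnp_uri.
        return os.environ.get('PRTE_LAUNCHED') == '1'

    @classmethod
    def get_coordinator_address(cls) -> str:
        # prterun-<hostname>-<jobid>@0.0;tcp4://<ip>:<port>  (see the dump above)
        namespace, address = os.environ['PMIX_SERVER_URI21'].split(';', maxsplit=1)
        ip = address.split('://', maxsplit=1)[1].rsplit(':', maxsplit=1)[0].split(',')[0]
        job_id = int(re.search(r'-(\d+)@', namespace).group(1))
        # Every process must agree on the coordinator port; deriving it
        # deterministically from the job id is one option (assumption),
        # giving a port in [61440, 65535].
        port = job_id % 2**12 + (65535 - 2**12 + 1)
        return f'{ip}:{port}'

    @classmethod
    def get_process_count(cls) -> int:
        return int(os.environ['OMPI_COMM_WORLD_SIZE'])

    @classmethod
    def get_process_id(cls) -> int:
        return int(os.environ['PMIX_RANK'])

    @classmethod
    def get_local_process_id(cls) -> int:
        return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

For the subclass to be picked up automatically, it would also need to be imported and registered alongside the existing cluster environments in jax._src.clusters, following whatever mechanism that package uses.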

nvcastet · Feb 27 '23 14:02

That is certainly a valid option! I'll go ahead and try that.

EricHallahan · Feb 27 '23 14:02

@EricHallahan could you contribute your patch, or maybe put it out there? It would be useful for me as well.

PhilipVinc · Aug 27 '24 14:08

(And do you know if there is a way to do the same for MPICH by any chance?)

PhilipVinc · Aug 27 '24 14:08

You can use auto-detection via mpi4py for that. See https://github.com/google/jax/pull/20174
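
For reference, user code does not change much with that route. A minimal usage sketch, assuming the mpi4py-based detection from the linked PR is available in the installed JAX version and is selected via the cluster_detection_method argument (check the current jax.distributed.initialize documentation):

# Launched with e.g.:  mpirun -n 4 python example.py
import jax

# With automatic cluster detection there is no need to pass the coordinator
# address, process count, or process id explicitly. The keyword below opts in
# to the mpi4py-based detection (assumption: supported by this JAX version).
jax.distributed.initialize(cluster_detection_method="mpi4py")

print(jax.process_index(), jax.process_count(), jax.local_device_count())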

nvcastet · Aug 27 '24 14:08

I know, but I don't want to. I get some errors with that...

PhilipVinc · Aug 27 '24 14:08