Proposal for generic mpi4py initialization of jax distributed module
jax.distributed.initialize() works, without arguments, on several but not all common MPI / Slurm parallel job launchers. Unfortunately, the environment variables used in the ompi_cluster.py implementation are not standardized. I'd like to submit a PR that uses mpi4py in a generic way to initialize the distributed mode automatically. A proposed file in jax/_src/clusters/ would look like this:
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import socket
from importlib.util import find_spec

from numpy import unique, where

from jax._src import clusters


class Mpi4pyCluster(clusters.ClusterEnv):

  @classmethod
  def is_env_present(cls) -> bool:
    # Relies on mpi4py:
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls) -> str:
    # Using mpi4py, figure out rank 0 and its hostname,
    # then broadcast the coordinator address to all ranks.
    from mpi4py import MPI

    # Get the global communicator:
    COMM_WORLD = MPI.COMM_WORLD

    # On rank 0, look up this host's IP and pick a port:
    if COMM_WORLD.Get_rank() == 0:
      hostname = socket.gethostname()
      host_ip = socket.gethostbyname(hostname)
      # Apparently, we want to pick a port in an ephemeral range. The port
      # is chosen once, on rank 0 only, so every rank receives the same
      # value from the broadcast below:
      port_id = hash(host_ip) % 2**12 + (65535 - 2**12 + 1)
      coordinator_address = f'{host_ip}:{port_id}'
    else:
      coordinator_address = None

    # Broadcast the coordinator address to all ranks:
    return COMM_WORLD.bcast(coordinator_address, root=0)

  @classmethod
  def get_process_count(cls) -> int:
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_size()

  @classmethod
  def get_process_id(cls) -> int:
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_rank()

  @classmethod
  def get_local_process_id(cls) -> int | None:
    # Using mpi4py, split the global communicator into sub-communicators
    # based on hostname. MPI will assign ranks within each split, which
    # yields the local (per-host) process ID.
    from mpi4py import MPI
    COMM_WORLD = MPI.COMM_WORLD

    hostname = socket.gethostname()
    all_hostnames = COMM_WORLD.gather(hostname, root=0)

    if COMM_WORLD.Get_rank() == 0:
      # Order all the hostnames and find the unique ones
      # (numpy sorts them automatically):
      unique_hosts = unique(all_hostnames)
    else:
      unique_hosts = None

    # Broadcast the list of unique hostnames:
    unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)

    # Find the integer index for this host in the list of hosts:
    i = int(where(unique_hosts == hostname)[0][0])

    # Split the communicator by host; the rank in the new, host-local
    # communicator IS the local process ID:
    new_comm = COMM_WORLD.Split(color=i)
    return int(new_comm.Get_rank())
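For reference, here is a minimal usage sketch of how this is intended to behave once the class is registered as a cluster environment (the launch command and script name below are illustrative):

# Hypothetical launch: mpiexec -n 8 python train.py
import jax

# With Mpi4pyCluster registered, no arguments are needed; the coordinator
# address, process count, and process/local IDs are auto-detected via mpi4py.
jax.distributed.initialize()

print(jax.process_index(), jax.process_count(), jax.local_device_count())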
Assuming you'd welcome this pull request, I had a question or two:
- How often are the functions get_coordinator_address and get_local_process_id called? Each of them does some MPI ops which, at very large scales, might be detrimental to performance. Caching the results of the functions would help.
- Assuming this proceeds, would it be preferable to have this as a separate class from the ompi_cluster.py implementation that uses env variables? Or as a fallback?
- Does the distributed initialization happen over TCP? If so, is it possible to bypass any of that? I expect that at the largest scales, this may become a performance bottleneck.
- I wrap all of the from mpi4py import MPI calls inside the functions of the class. That line will raise an error if mpi4py is not installed; I'm assuming those functions will never get called if is_env_present returns False?
For some context, I work at Argonne National Laboratory and run JAX on our supercomputers. Currently we're running jobs with O(1000) JAX processes on A100 GPUs, but in the (near) future this will hopefully become O(100000) processes on Intel GPUs. I have been using mpi4jax for scaling (it's also great; you probably know about it already), but there are use cases for JAX's distributed package as well.
I fully support this effort!
This would greatly simplify setup for many academic users. I've had to fight with setting up jax.distributed on weird clusters several times, and this seems like a great idea.
For what it's worth, I have done this and tested it on our Polaris supercomputer at Argonne National Lab. The changes are pretty small: one additional file (Mpi4pyCluster.py) and one modified file.
https://github.com/coreyjadams/jax
Some open questions that I'm not entirely clear on: related to #9582, the HTTP proxy variables probably have to be unset to use the distributed.initialize functionality. It'd be nice to at least include a warning in the timeout notifications suggesting the user try again with them unset. More aggressive options (unsetting the variables, raising warnings or exceptions preemptively) all have problems. But I don't know exactly where the timeout error is emanating from, if it happens.
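For concreteness, the workaround being described amounts to something like this (a sketch, assuming the proxy variables are what block the coordinator connection):

import os
import jax

# Clear any HTTP proxy settings before initializing, so the connection to the
# coordinator is not routed through the site proxy.
for var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
    os.environ.pop(var, None)

jax.distributed.initialize()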
Would this PR be welcomed by the JAX team?
Ultimately, we (the JAX maintainers) aren't MPI users. So the MPI-using community will be the best judge of whether this approach works well! It looks plausible to me, and you should send a PR.
@nvcastet originally contributed that file, I think. Perhaps they have comments.
Would the mpi4py approach be superior in all cases? Could we just have the mpi4py version?
To improve the error message, I'd probably just stick a Python-level try block around the .connect() call that perhaps looks for an HTTP proxy environment variable and warns. Or we could just warn right before the connect call, if we think that's likely always an error. The error itself originates from deep in the gRPC stack, so it's probably easiest to provide more information in the Python caller. Send a PR?
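As a rough sketch of that suggestion (the client object and its connect() call below are hypothetical stand-ins for the real internals, not the actual JAX code):

import os
import warnings

def connect_with_proxy_hint(client):
    # 'client' stands in for the distributed runtime client; its connect()
    # call is where the timeout would surface.
    try:
        client.connect()
    except Exception:
        proxies = [v for v in ("http_proxy", "https_proxy",
                               "HTTP_PROXY", "HTTPS_PROXY") if os.environ.get(v)]
        if proxies:
            warnings.warn(
                "jax.distributed failed to connect and proxy variables "
                f"{proxies} are set; try unsetting them and re-running.")
        raise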
For some context, I work at Argonne National Laboratory and run JAX on our supercomputers. Currently we're running jobs with O(1000) JAX processes on A100 GPUs, but in the (near) future this will hopefully become O(100000) processes on Intel GPUs.
By the way, we'd love to hear more about your workload; it's always great to hear about people using JAX at scale. Do you use JAX's distributed jit, shard_map, or pmap?
OK, I'll put a PR together and send it in. I thought of another question: are these cluster functions likely to be called more than once or twice at startup? At large scales, doing an MPI broadcast and Comm_split will incur some overhead that, if it happens just at startup, is worth it. If it happens often, it would be inefficient and I should cache the output of those functions.
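If caching does turn out to matter, one option is simply to memoize the expensive results, e.g. (a sketch; the helper name is illustrative, and functools.lru_cache is just one way to do it):

import functools
import socket

@functools.lru_cache(maxsize=1)
def _cached_coordinator_address() -> str:
    # The MPI broadcast runs at most once per process; subsequent calls
    # return the memoized string.
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    addr = None
    if comm.Get_rank() == 0:
        host_ip = socket.gethostbyname(socket.gethostname())
        port_id = hash(host_ip) % 2**12 + (65535 - 2**12 + 1)
        addr = f'{host_ip}:{port_id}'
    return comm.bcast(addr, root=0)

get_coordinator_address (and similarly get_local_process_id) could then just delegate to the cached helper.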
The mpi4py approach is more generic but introduces an additional dependency (mpi4py, of course). I think it's worthwhile to leave the cluster implementations for the most common job managers (Slurm, etc.) in place, so as not to force mpi4py on users who are already happy. Since those approaches just check for the existence of env variables, they add no significant overhead.
Perhaps a workable solution - since the MPI users are mostly from the big supercomputer facilities where this HTTP error will show up - is to emit a one-off warning from the Mpi4pyCluster file. This has the advantage of doing it after MPI init and we can control it to emit only from rank 0.
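A minimal sketch of what that one-off, rank-0-only warning could look like (the placement inside the Mpi4pyCluster module and the exact wording are assumptions):

import os
import warnings

_proxy_warning_emitted = False

def _maybe_warn_about_proxy_vars():
    # Emit the warning once, and only from global rank 0, after MPI is up.
    global _proxy_warning_emitted
    from mpi4py import MPI
    if _proxy_warning_emitted or MPI.COMM_WORLD.Get_rank() != 0:
        return
    if any(os.environ.get(v) for v in ("http_proxy", "https_proxy",
                                       "HTTP_PROXY", "HTTPS_PROXY")):
        warnings.warn(
            "HTTP proxy variables are set; jax.distributed.initialize() may "
            "time out connecting to the coordinator unless they are unset.")
    _proxy_warning_emitted = True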
By the way, we'd love to hear more about your workload; it's always great to hear about people using JAX at scale. Do you use JAX's distributed jit, shard_map, or pmap?
Happy to share! We actually just released a public version this week: https://github.com/Nuclear-Physics-with-Machine-Learning/JAX_QMC_Public
The workload is doing variational monte carlo with quantum many body systems. We've actually tried this in torch, libtorch, julia, and tensorflow too! JAX is the winner for multiple reasons:
- We need 2nd derivatives and Jacobians that were horribly inefficient in torch when we started this. The situation might be different now that torch has vmap, but I haven't checked.
- Tensorflow was really good until we needed more complicated functions traced with xla and @tf.function - then I was seeing compile times longer than an hour and couldn't actually run jobs!
- libtorch was better than torch but even with some work optimizing it I couldn't improve the concurrency of the jacobian calculation that we need.
- Julia was just really frustrating to get performance out of. Static arrays was just ... not working how I expected it to. Certainly a user error :)
For scale out, I actually use mpi4jax in that repository above - it gives good performance and I've scaled this out to 2000+ A100s on our supercomputer. Our bottleneck at scale is an allreduce inside of a conjugate gradient solve; we're calling it a few hundred times per algorithm iteration and it starts to dominate. I actually presented this last fall; you can see the scaling plots around slide 24 here: https://www.alcf.anl.gov/sites/default/files/2023-10/ALCF-HandsOnWorkshop-VariationalQMC-Nuclei.pdf
My only real complaint about scaling out JAX was the inability to do MPI reductions in place - before the C.G. implementation for part of our workload, we used a Cholesky solve that needed a big (15+ GB) allreduce and having a separate buffer doubled our GPU VRAM usage. It's a big part of why we switched to C.G.
I'm working on a shard_map alternative to mpi4jax too. I'm not sure which will be more efficient - it's worth testing both - but at the moment we're blocked in mpi4jax by an unexpected bug: https://github.com/mpi4jax/mpi4jax/issues/229
Thanks! Corey
jax.distributed.initialize() works, without arguments, on several but not all common MPI / Slurm parallel job launchers.
From what I remember, slurm_cluster.py should work with all Slurm jobs independent of MPI/PMI* usage, since Slurm will set those env variables. Is it because you are using mpirun/mpiexec instead of srun to launch your MPI application?
On the other hand, ompi_cluster.py was made for detecting OpenMPI (ORTE runtime, the default one up to v5) applications launched via mpirun/mpiexec.
See https://github.com/google/jax/issues/14576 for v5.
Do you know which MPI distribution you are running on your cluster?
We originally did not go the mpi4py path because we did not want to initialize MPI under the covers without the user's permission. Applications can initialize MPI differently from mpi4py's default init.
def is_env_present(cls) -> bool:
# Relies on mpi4py:
return find_spec("mpi4py") is not None
Here, the detection could be an issue, since having the package installed in the environment does not mean you are launching an MPI application; the job could have been launched with a non-MPI-compatible launcher (e.g. srun --mpi=none ...).
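One way to tighten the detection, sketched under the assumption that common MPI launchers and PMI layers export at least one of these variables (the list is illustrative, not exhaustive), is to require both mpi4py and some evidence of an MPI launch:

import os
from importlib.util import find_spec

# Variables commonly exported by MPI launchers / PMI layers (illustrative):
_MPI_LAUNCH_VARS = ("OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "PMI_RANK", "PMIX_RANK")

def looks_like_mpi_job() -> bool:
    return (find_spec("mpi4py") is not None
            and any(v in os.environ for v in _MPI_LAUNCH_VARS))

Even this is only a heuristic, which is part of why an explicit opt-in may be the safer default.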
I like the simplicity of mpi4py, but I am scared of enabling it fully transparently in JAX for the reasons above.
A compromise would be to put Mpi4pyCluster in a utils Python file so the user can opt in to it by just importing it. Importing it before calling initialize() will register it automatically as a known env to JAX:
from myutils import Mpi4pyCluster
....
jax.distributed.initialize()
Another potential solution for mpi4py users is to have mpi4jax define Mpi4pyCluster at init time of the mpi4jax module, since mpi4jax already has a hard dependency on mpi4py anyway.
On our systems we're using MPICH from HPE, but in my experience, when you are dealing with the vendor-optimized MPI implementations from HPE/Cray/Intel/IBM/etc., the env variables they set are different. So the OMPI variables are fine, but not necessarily generic.
And yes - we're not launching with Slurm, the job scheduler on this cluster is PBSPro. I suspect we're running on very different clusters!
Do you think there is much overlap in the use cases of mpi4jax vs. jax.distributed? If users want to use mpi4jax, they are likely to just use that; users of pmap/shard_map etc. are not likely to call mpi4jax reductions on sharded tensors - unless there is some use case I'm missing? I agree it would be bad to limit users to one OR the other, though; better to maintain both as viable options, separate or together.
What do you think about an optional argument to jax.distributed.initialize that allows the user to select an auto-init method? Legacy methods that just read env variables could leave it blank, but it would also allow the user to prioritize one init method over another. For example, a user on a Slurm system with mpi4py installed could call jax.distributed.initialize(auto_init_method="slurm") and know that mpi4py will not be initialized. Or they could call jax.distributed.initialize(auto_init_method="mpi") and force the use of mpi4py. Perhaps jax.distributed.initialize(auto_init_method="any") could let JAX take the wheel and figure out the initialization parameters in whatever way works first?
Anyways, maybe the real answer is to just update the documentation with "If you're launching JAX on a cluster with MPI, here is a technique to pick suitable initialization parameters via mpi4py ... "
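Such a documentation snippet might look roughly like this (a sketch; the port choice is an arbitrary placeholder, and in practice you'd pick one known to be free on rank 0's host):

import socket
import jax
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Rank 0 resolves its own IP and picks a port, then shares the address:
addr = None
if comm.Get_rank() == 0:
    addr = f'{socket.gethostbyname(socket.gethostname())}:56789'
addr = comm.bcast(addr, root=0)

jax.distributed.initialize(
    coordinator_address=addr,
    num_processes=comm.Get_size(),
    process_id=comm.Get_rank(),
)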
Hi Corey,
I agree mpi4jax and jax.distributed may be used together or separately.
I really like your mpi4py approach for catching all the vendor-specific MPI implementations. Figuring out all the potential env variables that the different vendors set (or don't set) is a nightmare, even if that route would not add the extra downsides mentioned above (extra dependency, initializing MPI, and extra communications).
The original ideal goal of jax.distributed.initialize() [with no args] was to be able to run the same script in different environments without any code change (not even an argument change) by doing auto-detection. This implies that when implementing the is_env_present() method of the environment class, if True is returned, we know for sure that the application is running in that environment and we can get the job parameters from the environment to initialize jax.distributed.
For mpi4py, I think we agree the user will need to opt in and say "I launched my job with an MPI-compatible launcher and I am good with leveraging mpi4py to collect the info needed for jax.distributed". For this new opt-in scenario, the user would need to be able to specify the method as you mentioned, with something like jax.distributed.initialize(spec_detection_method="mpi4py").
@hawkinsp and @skye what are your thoughts?
Just so we can have something clear to discuss: I opened a PR #20174 based on what we've talked about here, using an exclusively opt-in method. Hopefully it proves useful!