dask-mpi
Stable MPI.COMM_WORLD for scaling out to hundreds of nodes
Previously, `initialize()` allowed creating the MPI comm world after importing `distributed`:
```python
from distributed import Client, Nanny, Scheduler
from distributed.utils import import_term
...
def initialize(
    ...
):
    if comm is None:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
```
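With the `comm=None` default, a script can therefore import `distributed` (via `dask_mpi`) long before `mpi4py` is ever touched. A minimal sketch of the usage pattern this default permits:

```python
# Sketch of the pattern the comm=None default permits: distributed is
# imported first, and MPI.COMM_WORLD is only created inside initialize().
from dask_mpi import initialize   # imports distributed under the hood
from distributed import Client

initialize()       # comm=None: mpi4py is imported and COMM_WORLD created here
client = Client()  # connects to the scheduler started on rank 0
```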
However, in my tests on large-scale clusters with hundreds of nodes, as the number of workers increases (typically beyond 32), the `initialize()` function gets stuck at:
```python
# scheduler
comm.bcast(scheduler.address, root=0)
comm.Barrier()
```
and
```python
# worker
scheduler_address = comm.bcast(None, root=0)
dask.config.set(scheduler_address=scheduler_address)
comm.Barrier()
```
This happens because some worker processes never receive the broadcast scheduler address.
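For context, the handshake itself is an ordinary mpi4py broadcast followed by a barrier; both are collective operations, so every rank must reach them or the job hangs. A standalone sketch of the pattern (`bcast_demo.py` is a hypothetical script name):

```python
# bcast_demo.py -- minimal sketch of the broadcast/barrier handshake.
# Run with e.g.: mpirun -n 4 python bcast_demo.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Rank 0 plays the scheduler; the address below is a stand-in value.
address = "tcp://10.0.0.1:8786" if rank == 0 else None

address = comm.bcast(address, root=0)  # collective: every rank must call this
comm.Barrier()                         # ...and this, or the whole job hangs
print(f"rank {rank} received {address}")
```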
After a long time spent debugging, I found that this is strongly related to the order of `import mpi4py` and `import distributed` (or `from distributed import`). My guess is that `distributed` makes some communication environment settings first, which then conflict with `mpi4py` when it tries to bootstrap MPI.COMM_WORLD afterwards.
By strictly requiring the user to create MPI.COMM_WORLD before calling the `initialize()` function, the problem above no longer occurs. In my tests, this scales out to more than 128 workers (possibly more; my resources were limited) without any hanging.
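In other words, the caller imports `mpi4py` and builds the communicator first, then hands it to `initialize()`. A minimal sketch of the proposed usage (assuming `initialize()` accepts the `comm` keyword shown in the excerpt above):

```python
# Create the communicator BEFORE distributed is imported, so mpi4py
# bootstraps MPI without interference.
from mpi4py import MPI

comm = MPI.COMM_WORLD

from dask_mpi import initialize   # imports distributed under the hood
from distributed import Client

initialize(comm=comm)   # pass the pre-created communicator explicitly
client = Client()       # connects to the scheduler started on rank 0
```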
Checks are failing because the new changes require calling `initialize()` with a defined `comm`, e.g. `initialize(comm=MPI.COMM_WORLD)`.
@YJHMITWEB: You are proposing that `comm=None` (the current default) be disallowed entirely, even though it works in many setups. That seems a bit restrictive to me. Is it fair to say that the documentation should mention that setting the `comm` parameter to an existing `mpi4py.MPI.Intracomm` can prevent hanging on some systems when scaling to large numbers of MPI ranks? Perhaps this should be not so much a Draconian modification as a documentation update.
Another modification I might be in favor of is to keep the current design but check `comm.Get_size()` after the internal `comm` is created (i.e., in the `comm=None` case). If `comm.Get_size()` is larger than the limit you have tested (e.g., above 32), issue a warning explaining why the user might be experiencing hangs and how to fix it. A rough sketch of what that check could look like follows below.