
Cannot see multiple GPUs when using Slurm (with proposed fix)

Open gabeweisz opened this issue 5 months ago • 0 comments

When using MaxText with Slurm, our jobs see only one GPU per node, because `jax.distributed` assumes one GPU per process when running under Slurm (see the JAX docs).
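To make the failure mode concrete, here is a simplified Python model of that default. It is our reading of JAX's Slurm cluster detection, not JAX's actual code: we assume detection keys off `SLURM_JOB_ID`, and that each process keeps only the GPU whose index equals its `SLURM_LOCALID`. The function name `slurm_default_local_device_ids` is hypothetical.

```python
from typing import List, Mapping, Optional, Sequence


def slurm_default_local_device_ids(
    env: Mapping[str, str],
    local_device_ids: Optional[Sequence[int]] = None,
) -> Optional[List[int]]:
  """Simplified model (an assumption, not JAX's code) of how
  jax.distributed fills in local_device_ids when it detects Slurm."""
  if local_device_ids is not None:
    # An explicit argument wins; the proposed fix relies on this.
    return list(local_device_ids)
  if "SLURM_JOB_ID" in env and "SLURM_LOCALID" in env:
    # One GPU per process: keep only the GPU matching the local task rank.
    return [int(env["SLURM_LOCALID"])]
  # Outside Slurm, None means "all local GPUs stay visible".
  return None


env = {"SLURM_JOB_ID": "123", "SLURM_LOCALID": "0"}
print(slurm_default_local_device_ids(env))                # [0]
print(slurm_default_local_device_ids(env, [0, 1, 2, 3]))  # [0, 1, 2, 3]
```

With one process per node, `SLURM_LOCALID` is 0 on every node, so under this model each node exposes only GPU 0, which matches the symptom above. Passing an explicit `local_device_ids` list sidesteps the fallback.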

This behavior can be overridden by passing `local_device_ids` to `jax.distributed.initialize`, so one way to fix this is to change `initialize_jax_for_gpu` as follows (max_utils.py line 243):

```python
import os

import jax

import max_logging


def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # Parse CUDA_VISIBLE_DEVICES (e.g. "0,1,2,3") into a list of ints.
    # If it is unset or empty, pass None, which keeps JAX's default
    # (one GPU per process under Slurm).
    visible = os.getenv("CUDA_VISIBLE_DEVICES", "")
    device_list = [int(d) for d in visible.split(",") if d.strip()] or None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
  max_logging.log(f"JAX global devices: {jax.devices()}")
```

This can probably use more robust error handling.
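On the error-handling point, one option is to move the parsing into a small helper that fails loudly on values it cannot turn into integer ids. The sketch below is hypothetical: `parse_local_device_ids` and its `renumber` flag are not part of MaxText or JAX. It assumes that `local_device_ids` are process-local ordinals. CUDA renumbers the GPUs listed in `CUDA_VISIBLE_DEVICES` starting from 0, so on a node where the variable is `4,5,6,7`, passing the raw values may not match the device indices JAX sees. We have not verified this against every JAX version, so the raw values remain available with `renumber=False`.

```python
from typing import List, Optional


def parse_local_device_ids(
    visible: Optional[str], renumber: bool = True
) -> Optional[List[int]]:
  """Hypothetical helper: convert CUDA_VISIBLE_DEVICES into local_device_ids.

  Returns None (keep JAX's default) when the variable is unset or blank.
  With renumber=True, returns 0..n-1, assuming the process sees the listed
  GPUs renumbered from 0; with renumber=False, returns the raw indices.
  """
  entries = [e.strip() for e in (visible or "").split(",") if e.strip()]
  if not entries:
    return None
  for entry in entries:
    # GPU UUIDs ("GPU-...") or MIG names cannot be passed as integer ids.
    if not entry.isdigit():
      raise ValueError(f"unsupported CUDA_VISIBLE_DEVICES entry: {entry!r}")
  ids = [int(e) for e in entries]
  if len(set(ids)) != len(ids):
    raise ValueError(f"duplicate GPU index in CUDA_VISIBLE_DEVICES: {visible!r}")
  return list(range(len(ids))) if renumber else ids


print(parse_local_device_ids("4,5,6,7"))                  # [0, 1, 2, 3]
print(parse_local_device_ids("4,5,6,7", renumber=False))  # [4, 5, 6, 7]
print(parse_local_device_ids(None))                       # None
```

In `initialize_jax_for_gpu`, the call would then pass `local_device_ids=parse_local_device_ids(os.getenv("CUDA_VISIBLE_DEVICES"))`. On nodes where Slurm exposes the GPUs as `0,1,...,n-1`, both modes return the same list.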

gabeweisz, Sep 04 '24 19:09