
Slurm initialization only supports one device per host

Open Findus23 opened this issue 2 years ago • 7 comments

I have access to an HPC cluster with multiple nodes, each of which has two GPUs. As I want to run computations that need the combined memory of many GPUs, I was looking into the multi-host setup.

The HPC cluster uses Slurm, so using it for initialisation seemed easiest. I created a simple test script:

import os

import jax

jax.distributed.initialize()
jax.config.update("jax_enable_x64", True)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))
print(jax.devices(), jax.local_devices())

And a slurm job script:

#!/bin/bash
#SBATCH --mail-type=ALL
#SBATCH --nodes=2
#SBATCH --tasks-per-node=1
#SBATCH --exclusive
#SBATCH --job-name=jax-multinode-test
#SBATCH --gpus=4

source $DATA/venv-jax/bin/activate
cd ~/jax-testing/
srun python distributed.py

I would then expect jax.devices() to contain 4 GPU devices, but the output is:

0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0)]
0,1
[gpu(id=0), gpu(id=1)] [gpu(id=1)]

So each of the two hosts only contributes one local device.

If I comment out jax.distributed.initialize() and therefore let both hosts do their own thing, both hosts detect both GPUs properly:

0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0), gpu(id=1)]
0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0), gpu(id=1)]

Technically, this is documented here in the function arguments: https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/distributed.py#L147-L149

But I am not sure what the reason for this limitation is, as in my experience having two GPUs per host is quite common.

And patching https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/clusters/slurm_cluster.py#L60-L62 to instead return None makes everything work the way one would expect:

0,1
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)] [gpu(id=0), gpu(id=1)]
0,1
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)] [gpu(id=2), gpu(id=3)]

(see https://github.com/google/jax/discussions/16789 for another issue I am having with doing this)
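The two outputs above can be modelled with a small pure-Python sketch. It assumes, as in the linked slurm_cluster.py, that the Slurm backend takes the local process id from SLURM_LOCALID and, when local_device_ids is not given, restricts each process to [local_process_id]; returning None instead leaves every visible GPU to the process. model_local_device_ids is a hypothetical helper and does not call JAX.

```python
def model_local_device_ids(env, num_visible_gpus, patched=False):
    """Hypothetical model of which local GPUs one process ends up using.

    Default Slurm path (as assumed here): local_device_ids = [SLURM_LOCALID].
    Patched path (get_local_process_id() returns None): all visible GPUs.
    This does not call JAX; it only mirrors the outputs shown in this issue.
    """
    local_process_id = None if patched else int(env["SLURM_LOCALID"])
    if local_process_id is None:
        return list(range(num_visible_gpus))
    return [local_process_id]


# Two nodes, one task per node, two GPUs per node: SLURM_LOCALID is "0" on both.
hosts = [{"SLURM_LOCALID": "0"}, {"SLURM_LOCALID": "0"}]
default = [model_local_device_ids(h, 2) for h in hosts]
patched = [model_local_device_ids(h, 2, patched=True) for h in hosts]
print(default, sum(len(d) for d in default))  # [[0], [0]] 2
print(patched, sum(len(d) for d in patched))  # [[0, 1], [0, 1]] 4
```

The totals (2 versus 4) correspond to the lengths of jax.devices() in the two runs.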

Findus23, Jul 19 '23

Same issue here. I tried multiple JAX versions, including 0.4.14.

#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-gpu=10
#SBATCH --ntasks=2 # 1 nodes
#SBATCH --output=slurm/logs/%x_%j.out

srun python jax_test.py

And jax_test.py:

import jax
import os
jax.distributed.initialize()

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

print("jax.device_count()", jax.device_count())
print("jax.local_device_count()", jax.local_device_count())
print("jax.devices()", jax.devices())

Output:

0,1,2,3,4,5,6,7
jax.device_count() 2
jax.local_device_count() 1
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1)]
0,1,2,3,4,5,6,7
jax.device_count() 2
jax.local_device_count() 1
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1)]

If I change get_local_process_id in slurm_cluster.py to

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    return None

the output becomes:

0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=8, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=9, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=10, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=11, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=12, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=13, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=14, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=15, process_index=1, slice_index=1)]
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=8, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=9, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=10, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=11, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=12, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=13, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=14, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=15, process_index=1, slice_index=1)]

vwxyzjn, Aug 04 '23

There are two ways of using multiple GPUs on one node: one process that handles all the GPUs, or a separate process for each GPU.

With multiple nodes, you end up with multiple processes anyway. One process per GPU is faster in some cases, so it is the recommended approach if your script supports it.

The current default assumes that, if you use Slurm, you tell Slurm to start one process per GPU. That isn't what you are doing. The code could perhaps be updated to detect this automatically, but it doesn't do that right now.

So I would suggest keeping initialize() but modifying your Slurm job to start one process per GPU: --tasks-per-node=2 in the original question.
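As a rough illustration of the two layouts, the sketch below splits one node's GPU ordinals among the tasks Slurm starts on that node. per_process_devices is a hypothetical helper; it assumes every task could see all of the node's GPUs and that the GPUs divide evenly among tasks, and it is not how JAX or Slurm actually assign devices.

```python
def per_process_devices(gpus_per_node, tasks_per_node):
    """Hypothetical helper: split one node's GPU ordinals evenly among its tasks.

    tasks_per_node == 1             -> one process drives every GPU on the node
    tasks_per_node == gpus_per_node -> one process per GPU (task t gets [t])
    """
    if tasks_per_node < 1 or gpus_per_node % tasks_per_node != 0:
        raise ValueError("GPUs per node must be a positive multiple of tasks per node")
    per_task = gpus_per_node // tasks_per_node
    return [list(range(t * per_task, (t + 1) * per_task)) for t in range(tasks_per_node)]


print(per_process_devices(2, 1))  # --tasks-per-node=1: [[0, 1]]
print(per_process_devices(2, 2))  # --tasks-per-node=2: [[0], [1]]
```

With one task per GPU, task t gets [t], which is the layout the current default expects.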

TODO: Let's keep this issue open to check whether initialize() can be updated to automatically detect how many GPUs each process should use.

nouiz, Aug 04 '23

OK, I figured out a quick fix: just call jax.distributed.initialize(local_device_ids=range(8)), where 8 is the number of GPUs per node in my job. Works like a charm.

import jax
import os
jax.distributed.initialize(local_device_ids=range(8))
print("jax.__version__", jax.__version__)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

print("jax.device_count()", jax.device_count())
print("jax.local_device_count()", jax.local_device_count())
print("jax.devices()", jax.devices())

Output:

jax.__version__ 0.4.13
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4), gpu(id=5), gpu(id=6), gpu(id=7), gpu(id=8), gpu(id=9), gpu(id=10), gpu(id=11), gpu(id=12), gpu(id=13), gpu(id=14), gpu(id=15)]
jax.__version__ 0.4.13
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4), gpu(id=5), gpu(id=6), gpu(id=7), gpu(id=8), gpu(id=9), gpu(id=10), gpu(id=11), gpu(id=12), gpu(id=13), gpu(id=14), gpu(id=15)]
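The global ids in these outputs are handed out process by process: with one device per process the two hosts own ids 0 and 1 (so jax.device_count() is 2), and with local_device_ids=range(8) they own 0-7 and 8-15 (so it is 16). The hypothetical helper below only reproduces that numbering from the per-process device counts, as observed in this thread; it is not JAX's implementation, and other platforms may number devices differently.

```python
def global_device_ids(local_counts):
    """Hypothetical helper: number devices globally, process by process.

    Mirrors the numbering seen in the outputs in this issue: process 0's
    devices come first, then process 1's, and so on.
    """
    ids, next_id = {}, 0
    for process_index, count in enumerate(local_counts):
        ids[process_index] = list(range(next_id, next_id + count))
        next_id += count
    return ids


print(global_device_ids([1, 1]))  # one device per process: {0: [0], 1: [1]}
print(global_device_ids([8, 8]))  # local_device_ids=range(8): 16 devices in total
```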

vwxyzjn, Aug 08 '23

@nouiz You are right: running multiple JAX processes on each host, each responsible for just one GPU, is one way to handle this. And indeed, with --tasks-per-node=2 JAX correctly initializes two processes that each handle one GPU.

@vwxyzjn Thank you for that idea. Using local_device_ids=range(2) does the exact same thing as the patch I mentioned above, but is of course a lot more elegant (and doesn't require modifying the jax source code).

Findus23, Aug 09 '23

@Findus23 do you have a patch to JAX that makes it handle this correctly? If so, where is it? It would be good to update JAX to handle this.

nouiz, Aug 09 '23

@nouiz Sorry, by patch I just mean editing https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/clusters/slurm_cluster.py#L60-L62 to return None instead.

But of course that isn't correct in the general case.

I just compared the environment variables of a --tasks-per-node=1 run and a --tasks-per-node=2 run and saw that Slurm sets SLURM_NTASKS_PER_NODE. So JAX could check whether it is 1 and, if nothing else is specified, pick all GPUs (i.e. use local_device_ids=range(num_gpus)).

But then again, maybe that is too much implicit magic, and it would be better to update the documentation to say that both one-GPU-per-process and one-process-per-host setups are possible, with a short example of how to specify each.
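The heuristic proposed here could be sketched as follows. choose_local_device_ids is a hypothetical function, not part of JAX; it takes the environment as a plain mapping so it can be tried without a cluster, and when SLURM_NTASKS_PER_NODE is missing or not 1 it falls back to the one-GPU-per-task behaviour keyed on SLURM_LOCALID.

```python
def choose_local_device_ids(env, num_gpus):
    """Hypothetical heuristic suggested in this comment; not part of JAX.

    If Slurm reports one task per node, let that task drive all of the node's
    GPUs; otherwise keep the one-GPU-per-task default keyed on SLURM_LOCALID.
    """
    if env.get("SLURM_NTASKS_PER_NODE") == "1":
        return list(range(num_gpus))
    return [int(env["SLURM_LOCALID"])]


print(choose_local_device_ids({"SLURM_NTASKS_PER_NODE": "1", "SLURM_LOCALID": "0"}, 2))  # [0, 1]
print(choose_local_device_ids({"SLURM_NTASKS_PER_NODE": "2", "SLURM_LOCALID": "1"}, 2))  # [1]
```

In a real job one would pass os.environ and the node's GPU count, then hand the result to jax.distributed.initialize(local_device_ids=...).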

Findus23, Aug 10 '23

I've run into this problem as well. When running distributed computations with JAX, which approach is preferable: assigning one process per node to manage all the local devices, or assigning multiple processes per node, each managing one local device?

tulvgengenr, Apr 18 '24

Thanks for the solutions; I was having the same issue. To automate this, why not do the following? It should work both when there is only one device per host and when there are multiple GPUs per task.

import os
import jax

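cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
# GPUs assigned to this task; if the variable is unset, pass None so that
# initialize() falls back to its default device selection.
local_device_ids = (
    [int(i) for i in cuda_visible_devices.split(",")] if cuda_visible_devices else None
)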
print(local_device_ids)

jax.distributed.initialize(local_device_ids=local_device_ids)
print("jax.__version__", jax.__version__)
print(jax.device_count())
print(jax.local_device_count())

for a job with the following options,

#SBATCH --nodes 2
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-task 2

outputs,

[0, 1]
jax.__version__ 0.4.30
4
2
[0, 1]
jax.__version__ 0.4.30
4
2
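One caveat, assuming local_device_ids refers to the GPU ordinals the process itself sees: when CUDA_VISIBLE_DEVICES is set, CUDA renumbers the listed GPUs from 0 inside the process, and the variable may also contain GPU UUIDs instead of integers. In the job above the two coincide ([0, 1]), but with, say, CUDA_VISIBLE_DEVICES=2,3 they would not. A variant that only counts the entries might look like this (local_ids_from_visible is a hypothetical helper):

```python
def local_ids_from_visible(value):
    """Hypothetical helper: map a CUDA_VISIBLE_DEVICES string to in-process ordinals.

    CUDA renumbers the listed GPUs 0..n-1 inside the process, so only the
    number of entries matters (this also tolerates UUID entries). Returns None
    when the variable is unset or empty, i.e. JAX's default selection.
    """
    if not value:
        return None
    entries = [e for e in value.split(",") if e.strip()]
    return list(range(len(entries)))


print(local_ids_from_visible("0,1"))  # [0, 1]
print(local_ids_from_visible("2,3"))  # [0, 1]
print(local_ids_from_visible(None))   # None
```

It would be used as jax.distributed.initialize(local_device_ids=local_ids_from_visible(os.environ.get("CUDA_VISIBLE_DEVICES"))).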

alexlyttle, Jul 02 '24