Multiple GPU devices broken?
The following code was working the last time I checked, but it no longer works with the pip wheels.
This is on an Ubuntu workstation with 4 A100 GPUs, using Python 3.9, CUDA 11.2 (I have tried several different 11.x CUDA versions, same error), and driver version 495.29.05.
$ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
....
$ cat test.py
import jax
import cupy as cp
import jax.numpy as jnp
import numpy as np

for d in range(cp.cuda.runtime.getDeviceCount()):
    cp.cuda.runtime.setDevice(d)
    a = np.random.rand(4)
    x = cp.array(a)
    print(a, x, x.device)

for d in jax.devices():
    a = np.random.rand(4)
    x = jax.device_put(jnp.array(a), d)
    print(a, x, d)
$ python test.py
[0.36026366 0.49166635 0.02453196 0.06113282] [0.36026366 0.49166635 0.02453196 0.06113282] <CUDA Device 0>
[0.58960801 0.87522834 0.06568807 0.91739894] [0.58960801 0.87522834 0.06568807 0.91739894] <CUDA Device 1>
[0.62792459 0.31616489 0.75599587 0.13468365] [0.62792459 0.31616489 0.75599587 0.13468365] <CUDA Device 2>
[0.63147776 0.40358606 0.82040093 0.95835464] [0.63147776 0.40358606 0.82040093 0.95835464] <CUDA Device 3>
[0.64609584 0.90184385 0.18761828 0.34506156] [0.6460959 0.90184385 0.18761827 0.34506157] gpu:0
[0.89907119 0.94931636 0.24913002 0.63580519] [0. 0. 0. 0.] gpu:1
[0.08435405 0.37044366 0.45964628 0.7886945 ] [0.08435405 0.37044367 0.45964628 0.7886945 ] gpu:2
[0.83871242 0.30987635 0.31794099 0.64514956] [0.8387124 0.30987635 0.31794098 0.6451495 ] gpu:3
Notice that the device array on gpu:1 contains only zeros.
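To turn that eyeball check into a hard failure, a small variant of the script above (a sketch, assuming JAX's default float32 dtype) can round-trip each array through every device and assert it comes back unchanged:

import jax
import jax.numpy as jnp
import numpy as np

for d in jax.devices():
    a = np.random.rand(4).astype(np.float32)  # match JAX's default float32
    x = jax.device_put(jnp.array(a), d)
    # Copy the buffer back to the host; an all-zeros transfer fails loudly here.
    np.testing.assert_allclose(np.asarray(jax.device_get(x)), a)
    print(d, "ok")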
Building from source using tag jax-v0.2.9 works fine. Am I doing something wrong or is it a bug?
$ git log -1
commit 10947403202af72299bf19feba40d0d6f7ecb350 (HEAD, tag: jax-v0.2.9)
Merge: a9e5fabe d2ae949b
Author: jax authors <[email protected]>
Date: Tue Jan 26 20:47:13 2021 -0800
Merge pull request #5528 from google:update-pypi
PiperOrigin-RevId: 354012572
$ pip install dist/jaxlib-0.1.60-cp39-none-manylinux2010_x86_64.whl
Processing ./dist/jaxlib-0.1.60-cp39-none-manylinux2010_x86_64.whl
Requirement already satisfied: numpy>=1.12 in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.22.2)
Requirement already satisfied: flatbuffers in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (2.0)
Requirement already satisfied: scipy in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.8.0)
Requirement already satisfied: absl-py in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.0.0)
Requirement already satisfied: six in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.60) (1.16.0)
Installing collected packages: jaxlib
Attempting uninstall: jaxlib
Found existing installation: jaxlib 0.3.0+cuda11.cudnn82
Uninstalling jaxlib-0.3.0+cuda11.cudnn82:
Successfully uninstalled jaxlib-0.3.0+cuda11.cudnn82
Successfully installed jaxlib-0.1.60
$ pip install -e .
Obtaining file:///home/zampins/sandbox/jax
Preparing metadata (setup.py) ... done
Requirement already satisfied: numpy>=1.12 in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (1.22.2)
Requirement already satisfied: absl-py in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (1.0.0)
Requirement already satisfied: opt_einsum in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (3.3.0)
Requirement already satisfied: six in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from absl-py->jax==0.2.9) (1.16.0)
Installing collected packages: jax
Attempting uninstall: jax
Found existing installation: jax 0.3.1
Uninstalling jax-0.3.1:
Successfully uninstalled jax-0.3.1
Running setup.py develop for jax
Successfully installed jax-0.2.9
$ python test.py
[0.72413586 0.03981569 0.82808771 0.08488717] [0.72413586 0.03981569 0.82808771 0.08488717] <CUDA Device 0>
[0.20310118 0.20229541 0.17232789 0.33406987] [0.20310118 0.20229541 0.17232789 0.33406987] <CUDA Device 1>
[0.38727162 0.83158815 0.74289441 0.16270889] [0.38727162 0.83158815 0.74289441 0.16270889] <CUDA Device 2>
[0.61987612 0.10876816 0.33720292 0.88087605] [0.61987612 0.10876816 0.33720292 0.88087605] <CUDA Device 3>
[0.22083954 0.78400095 0.12021694 0.49609661] [0.22083955 0.78400093 0.12021694 0.4960966 ] gpu:0
[0.32103459 0.01537868 0.96763227 0.10868705] [0.32103458 0.01537868 0.9676323 0.10868705] gpu:1
[0.3957342 0.86166002 0.87006812 0.80167636] [0.3957342 0.86166 0.87006813 0.8016764 ] gpu:2
[0.72442095 0.83735143 0.3064454 0.69193466] [0.72442096 0.83735144 0.3064454 0.69193465] gpu:3
I'm unable to reproduce this. I created a GCP VM with 4xA100 GPUs, installed CUDA toolkit 11.2 and driver 495.44. I installed Python 3.9 and packages cupy-cuda112 v10.2.0, jax v0.3.1, and jaxlib v0.3.0.
I get:
[0.60967659 0.99469146 0.42833437 0.92369786] [0.60967659 0.99469146 0.42833437 0.92369786] <CUDA Device 0>
[0.90148377 0.2266339 0.09378169 0.68339479] [0.90148377 0.2266339 0.09378169 0.68339479] <CUDA Device 1>
[0.08509242 0.02069798 0.56742285 0.62548422] [0.08509242 0.02069798 0.56742285 0.62548422] <CUDA Device 2>
[0.00875828 0.72183364 0.4777141 0.8376914 ] [0.00875828 0.72183364 0.4777141 0.8376914 ] <CUDA Device 3>
[0.82617048 0.02544475 0.54643491 0.67557591] [0.8261705 0.02544475 0.54643494 0.6755759 ] gpu:0
[0.41064849 0.76792539 0.2629806 0.62022472] [0.4106485 0.7679254 0.2629806 0.6202247] gpu:1
[0.59699067 0.56032479 0.23362261 0.47442987] [0.59699064 0.5603248 0.23362261 0.47442988] gpu:2
[0.77527778 0.36935768 0.09267553 0.23960158] [0.7752778 0.36935768 0.09267553 0.23960158] gpu:3
Without a way to reproduce the problem, I won't be able to debug it. Are you by any chance able to give complete instructions for reproducing it in a cloud VM of some sort, so that I can follow them?
Note that jax v0.2.9 is pretty old; there are dozens of releases between 0.2.9 and 0.3.0. (Version numbers are not decimals, so the third component can go above 9.)
Can you share the output of nvidia-smi topo -m?
$ nvidia-smi topo -m
GPU0 GPU1 GPU2 GPU3 mlx5_0 mlx5_1 mlx5_2 mlx5_3 CPU Affinity NUMA Affinity
GPU0 X PXB SYS SYS SYS SYS SYS SYS 0-27 0
GPU1 PXB X SYS SYS SYS SYS SYS SYS 0-27 0
GPU2 SYS SYS X PXB PXB PXB PXB PXB 28-55 1
GPU3 SYS SYS PXB X PXB PXB PXB PXB 28-55 1
mlx5_0 SYS SYS PXB PXB X PIX PIX PIX
mlx5_1 SYS SYS PXB PXB PIX X PIX PIX
mlx5_2 SYS SYS PXB PXB PIX PIX X PIX
mlx5_3 SYS SYS PXB PXB PIX PIX PIX X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
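Since the failure is device-specific, it may also be worth checking what the CUDA runtime reports for peer access between each pair of GPUs. A minimal sketch using the cupy bindings already imported in test.py (cp.cuda.runtime.deviceCanAccessPeer wraps cudaDeviceCanAccessPeer):

import itertools
import cupy as cp

n = cp.cuda.runtime.getDeviceCount()
for a, b in itertools.permutations(range(n), 2):
    # 1 means device a can directly access device b's memory; 0 means transfers go via the host.
    print(a, "->", b, cp.cuda.runtime.deviceCanAccessPeer(a, b))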
Was this issue resolved? @stefanozampini
Closing due to inactivity; feel free to reopen with more info!
Edit: for issues like this, it's a good idea to check the correctness of GPU communication with nccl-tests.
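As a rough application-level analogue (not a substitute for nccl-tests), a pmapped psum exercises the NCCL all-reduce path across all local devices; a minimal sketch:

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
# Each device gets one element; psum over 'i' all-reduces across devices through NCCL.
out = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(jnp.arange(n, dtype=jnp.float32))
# Every entry should equal 0 + 1 + ... + (n - 1).
np.testing.assert_allclose(np.asarray(out), np.full(n, n * (n - 1) / 2, dtype=np.float32))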