
Multiple GPU devices broken?

stefanozampini opened this issue on Mar 4, 2022 · 4 comments

The following code was working the last time I checked, but it no longer works with the pip wheels. This is on an Ubuntu workstation with 4 A100 GPUs, using Python 3.9, CUDA 11.2 (I have tried several different 11.x CUDA versions, with the same error), and driver version 495.29.05.

$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
....

$ cat test.py 
import jax
import cupy as cp
import jax.numpy as jnp
import numpy as np

for d in range(cp.cuda.runtime.getDeviceCount()):
  cp.cuda.runtime.setDevice(d)
  a = np.random.rand(4)
  x = cp.array(a)
  print(a,x,x.device)

for d in jax.devices():
  a = np.random.rand(4)
  x = jax.device_put(jnp.array(a),d)
  print(a,x,d)

$ python test.py 
[0.36026366 0.49166635 0.02453196 0.06113282] [0.36026366 0.49166635 0.02453196 0.06113282] <CUDA Device 0>
[0.58960801 0.87522834 0.06568807 0.91739894] [0.58960801 0.87522834 0.06568807 0.91739894] <CUDA Device 1>
[0.62792459 0.31616489 0.75599587 0.13468365] [0.62792459 0.31616489 0.75599587 0.13468365] <CUDA Device 2>
[0.63147776 0.40358606 0.82040093 0.95835464] [0.63147776 0.40358606 0.82040093 0.95835464] <CUDA Device 3>
[0.64609584 0.90184385 0.18761828 0.34506156] [0.6460959  0.90184385 0.18761827 0.34506157] gpu:0
[0.89907119 0.94931636 0.24913002 0.63580519] [0. 0. 0. 0.] gpu:1
[0.08435405 0.37044366 0.45964628 0.7886945 ] [0.08435405 0.37044367 0.45964628 0.7886945 ] gpu:2
[0.83871242 0.30987635 0.31794099 0.64514956] [0.8387124  0.30987635 0.31794098 0.6451495 ] gpu:3

Notice that gpu:1 device array contains only zeros.
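(As an aside, the small last-digit differences on the healthy devices are expected: `jnp.array` defaults to float32 while NumPy generates float64. Only the all-zeros result is anomalous.) The round-trip check above can be automated with a small helper; the helper name and tolerances below are my own, not part of the report:

```python
import numpy as np

def roundtrip_ok(put, n=4, seed=0):
    """Push a random float64 array through `put` (e.g. a host -> device ->
    host round trip) and check it survives, allowing for float32 rounding."""
    a = np.random.default_rng(seed).random(n)
    x = np.asarray(put(a), dtype=np.float64)
    return bool(np.allclose(a, x, rtol=1e-6, atol=1e-6))

# With JAX (assuming jax/jaxlib are installed):
#   for d in jax.devices():
#       print(d, roundtrip_ok(lambda a: jax.device_put(jnp.array(a), d)))

print(roundtrip_ok(lambda a: a.astype(np.float32)))  # float32 rounding only: True
print(roundtrip_ok(lambda a: np.zeros_like(a)))      # corrupted transfer: False
```

A device that returns all zeros, as gpu:1 does above, fails this check; a device that merely rounds to float32 passes.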

Building from source using tag jax-v0.2.9 works fine. Am I doing something wrong or is it a bug?

$ git log -1
commit 10947403202af72299bf19feba40d0d6f7ecb350 (HEAD, tag: jax-v0.2.9)
Merge: a9e5fabe d2ae949b
Author: jax authors <[email protected]>
Date:   Tue Jan 26 20:47:13 2021 -0800

    Merge pull request #5528 from google:update-pypi
    
    PiperOrigin-RevId: 354012572

$ pip install dist/jaxlib-0.1.60-cp39-none-manylinux2010_x86_64.whl 
Processing ./dist/jaxlib-0.1.60-cp39-none-manylinux2010_x86_64.whl
Requirement already satisfied: numpy>=1.12 in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.22.2)
Requirement already satisfied: flatbuffers in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (2.0)
Requirement already satisfied: scipy in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.8.0)
Requirement already satisfied: absl-py in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jaxlib==0.1.60) (1.0.0)
Requirement already satisfied: six in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.60) (1.16.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.0+cuda11.cudnn82
    Uninstalling jaxlib-0.3.0+cuda11.cudnn82:
      Successfully uninstalled jaxlib-0.3.0+cuda11.cudnn82
Successfully installed jaxlib-0.1.60

$ pip install -e .
Obtaining file:///home/zampins/sandbox/jax
  Preparing metadata (setup.py) ... done
Requirement already satisfied: numpy>=1.12 in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (1.22.2)
Requirement already satisfied: absl-py in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (1.0.0)
Requirement already satisfied: opt_einsum in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from jax==0.2.9) (3.3.0)
Requirement already satisfied: six in /home/zampins/Devel/miniforge/envs/testjax/lib/python3.9/site-packages (from absl-py->jax==0.2.9) (1.16.0)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.3.1
    Uninstalling jax-0.3.1:
      Successfully uninstalled jax-0.3.1
  Running setup.py develop for jax
Successfully installed jax-0.2.9

$ python test.py 
[0.72413586 0.03981569 0.82808771 0.08488717] [0.72413586 0.03981569 0.82808771 0.08488717] <CUDA Device 0>
[0.20310118 0.20229541 0.17232789 0.33406987] [0.20310118 0.20229541 0.17232789 0.33406987] <CUDA Device 1>
[0.38727162 0.83158815 0.74289441 0.16270889] [0.38727162 0.83158815 0.74289441 0.16270889] <CUDA Device 2>
[0.61987612 0.10876816 0.33720292 0.88087605] [0.61987612 0.10876816 0.33720292 0.88087605] <CUDA Device 3>
[0.22083954 0.78400095 0.12021694 0.49609661] [0.22083955 0.78400093 0.12021694 0.4960966 ] gpu:0
[0.32103459 0.01537868 0.96763227 0.10868705] [0.32103458 0.01537868 0.9676323  0.10868705] gpu:1
[0.3957342  0.86166002 0.87006812 0.80167636] [0.3957342  0.86166    0.87006813 0.8016764 ] gpu:2
[0.72442095 0.83735143 0.3064454  0.69193466] [0.72442096 0.83735144 0.3064454  0.69193465] gpu:3

stefanozampini avatar Mar 04 '22 15:03 stefanozampini

I'm unable to reproduce this. I created a GCP VM with 4xA100 GPUs, installed CUDA toolkit 11.2 and driver 495.44. I installed Python 3.9 and packages cupy-cuda112 v10.2.0, jax v0.3.1, and jaxlib v0.3.0.

I get:

[0.60967659 0.99469146 0.42833437 0.92369786] [0.60967659 0.99469146 0.42833437 0.92369786] <CUDA Device 0>
[0.90148377 0.2266339  0.09378169 0.68339479] [0.90148377 0.2266339  0.09378169 0.68339479] <CUDA Device 1>
[0.08509242 0.02069798 0.56742285 0.62548422] [0.08509242 0.02069798 0.56742285 0.62548422] <CUDA Device 2>
[0.00875828 0.72183364 0.4777141  0.8376914 ] [0.00875828 0.72183364 0.4777141  0.8376914 ] <CUDA Device 3>
[0.82617048 0.02544475 0.54643491 0.67557591] [0.8261705  0.02544475 0.54643494 0.6755759 ] gpu:0
[0.41064849 0.76792539 0.2629806  0.62022472] [0.4106485 0.7679254 0.2629806 0.6202247] gpu:1
[0.59699067 0.56032479 0.23362261 0.47442987] [0.59699064 0.5603248  0.23362261 0.47442988] gpu:2
[0.77527778 0.36935768 0.09267553 0.23960158] [0.7752778  0.36935768 0.09267553 0.23960158] gpu:3

Without a way to reproduce the problem, I'm not going to be able to debug it. Are you by any chance able to give complete instructions for reproducing it in a cloud VM of some sort, so that I could follow them myself?

Note that jax v0.2.9 is pretty old; there are dozens of releases between 0.2.9 and 0.3.0. (The version components compare numerically, not as decimal fractions.)
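The version-ordering point can be checked programmatically; here is a sketch using the `packaging` library (assumed installed, as it ships with most pip setups):

```python
from packaging.version import Version

# Version components compare numerically, so "0.2.28" sorts before "0.3.0"
# even though 0.28 > 0.3 as decimal fractions.
assert Version("0.2.9") < Version("0.2.28") < Version("0.3.0")
print(sorted(["0.3.0", "0.2.9", "0.2.28"], key=Version))  # ['0.2.9', '0.2.28', '0.3.0']
```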

hawkinsp avatar Mar 04 '22 16:03 hawkinsp

Can you share the output of nvidia-smi topo -m ?

hawkinsp avatar Mar 07 '22 18:03 hawkinsp

$ nvidia-smi topo -m
	GPU0	GPU1	GPU2	GPU3	mlx5_0	mlx5_1	mlx5_2	mlx5_3	CPU Affinity	NUMA Affinity
GPU0	 X 	PXB	SYS	SYS	SYS	SYS	SYS	SYS	0-27	0
GPU1	PXB	 X 	SYS	SYS	SYS	SYS	SYS	SYS	0-27	0
GPU2	SYS	SYS	 X 	PXB	PXB	PXB	PXB	PXB	28-55	1
GPU3	SYS	SYS	PXB	 X 	PXB	PXB	PXB	PXB	28-55	1
mlx5_0	SYS	SYS	PXB	PXB	 X 	PIX	PIX	PIX		
mlx5_1	SYS	SYS	PXB	PXB	PIX	 X 	PIX	PIX		
mlx5_2	SYS	SYS	PXB	PXB	PIX	PIX	 X 	PIX		
mlx5_3	SYS	SYS	PXB	PXB	PIX	PIX	PIX	 X 		

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
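Since the topology above shows the GPUs split across two NUMA nodes with no NVLink, peer-to-peer capability may differ between device pairs. A hedged helper to tabulate peer access follows; the predicate is injected so the logic can be exercised without GPUs (with cupy, you would pass `cp.cuda.runtime.deviceCanAccessPeer`):

```python
def p2p_matrix(can_access, n):
    """Build an n x n peer-access table from a predicate
    can_access(i, j) -> bool (e.g. cupy's deviceCanAccessPeer)."""
    return [[(i != j) and bool(can_access(i, j)) for j in range(n)]
            for i in range(n)]

# On a multi-GPU box with cupy installed (untested sketch):
#   import cupy as cp
#   print(p2p_matrix(cp.cuda.runtime.deviceCanAccessPeer,
#                    cp.cuda.runtime.getDeviceCount()))

# Illustrative stand-in: pretend only devices on the same NUMA node
# (0-1 and 2-3, matching the topology above) can peer.
same_node = lambda i, j: (i < 2) == (j < 2)
print(p2p_matrix(same_node, 4))
```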

stefanozampini avatar Mar 07 '22 18:03 stefanozampini

Was this issue resolved? @stefanozampini

sudhakarsingh27 avatar Aug 08 '22 20:08 sudhakarsingh27

Closing due to inactivity; feel free to reopen with more info!

Edit: for issues like this, it's a good idea to verify the correctness of GPU-to-GPU communication with nccl-tests.
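A typical nccl-tests run looks like the following (paths and message sizes are illustrative; this requires CUDA, NCCL, and the GPUs under test):

```shell
git clone https://github.com/NVIDIA/nccl-tests.git
cd nccl-tests
make CUDA_HOME=/usr/local/cuda                 # add NCCL_HOME=... if NCCL is not in a default path
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 4 # all-reduce across 4 GPUs, message sizes 8B..128M
```

The output reports bandwidth and, crucially for a corruption bug like this one, an error count per message size; any nonzero `#wrong` column points at a hardware or driver-level communication problem rather than a JAX bug.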

sudhakarsingh27 avatar Aug 24 '22 19:08 sudhakarsingh27