
Using GPUs with Jax

Open akhilnadigatla opened this issue 3 years ago • 5 comments

Discussed in https://github.com/google/jax/discussions/9858

Originally posted by akhilnadigatla March 11, 2022

Hello!

I am trying to get Jax running on my GPUs, but I get these error messages:

I0311 16:52:36.027944 140563802497728 xla_bridge.py:247] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
2022-03-11 16:52:36.058206: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
I0311 16:52:36.058603 140563802497728 xla_bridge.py:247] Unable to initialize backend 'gpu': FAILED_PRECONDITION: No visible GPU devices.
I0311 16:52:36.059262 140563802497728 xla_bridge.py:247] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
W0311 16:52:36.059466 140563802497728 xla_bridge.py:252] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I have looked for solutions to these problems and made sure the driver, CUDA, and cuDNN versions are as expected. I am working on Ubuntu 18.04, and the outputs of some key commands are:

>> nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 8000     Off  | 00000000:15:00.0 Off |                  Off |
| 61%   80C    P2   248W / 260W |  32490MiB / 48601MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 8000     Off  | 00000000:2D:00.0 Off |                  Off |
| 76%   86C    P2   261W / 260W |   7237MiB / 48600MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
>> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Wed_Jun__2_19:15:15_PDT_2021
Cuda compilation tools, release 11.4, V11.4.48
Build cuda_11.4.r11.4/compiler.30033411_0
>> cat /usr/local/cuda/include/cudnn_version.h
...
/**
 * \file: The master cuDNN version file.
 */

#ifndef CUDNN_VERSION_H_
#define CUDNN_VERSION_H_

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 2
#define CUDNN_PATCHLEVEL 4

#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)

#endif /* CUDNN_VERSION_H */
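
For comparing the installed cuDNN against what a jaxlib wheel expects, the three `#define` lines above can be read programmatically. This is a minimal sketch: the helper name `parse_cudnn_version` and the inline sample text are illustrative assumptions, and on a real machine the text would come from the header file shown above.

```python
import re

def parse_cudnn_version(header_text):
    """Return (major, minor, patch) from the text of cudnn_version.h."""
    fields = {}
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        # Match lines such as "#define CUDNN_MAJOR 8".
        match = re.search(rf"^\s*#define\s+{name}\s+(\d+)\s*$", header_text, re.MULTILINE)
        if match is None:
            raise ValueError(f"{name} not found in header text")
        fields[name] = int(match.group(1))
    return (fields["CUDNN_MAJOR"], fields["CUDNN_MINOR"], fields["CUDNN_PATCHLEVEL"])

# Sample text mirroring the header excerpt above (hypothetical stand-in for the file).
SAMPLE_HEADER = """\
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 2
#define CUDNN_PATCHLEVEL 4
"""

print(parse_cudnn_version(SAMPLE_HEADER))  # (8, 2, 4)
```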

All the path variables (to my knowledge) have been set as expected:

export PATH=/usr/local/cuda-11.4/bin:${PATH}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH}
export LD_LIBRARY_PATH=/usr/local/cuda-11.4/lib64:${LD_LIBRARY_PATH}
export CUDA_HOME=/usr/local/cuda
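
A quick way to double-check exports like these is to confirm that each directory on `LD_LIBRARY_PATH` actually exists in the shell that launches Python. The sketch below uses only the standard library; the helper name `existing_library_dirs` is an illustrative assumption, and it only checks for directories, not for specific CUDA libraries inside them.

```python
import os

def existing_library_dirs(ld_library_path):
    """Split an LD_LIBRARY_PATH-style string and report which entries exist."""
    report = []
    for entry in ld_library_path.split(os.pathsep):
        if not entry:
            continue  # skip empty entries left by leading/trailing separators
        report.append((entry, os.path.isdir(entry)))
    return report

# Report on the value seen by this Python process.
for path, exists in existing_library_dirs(os.environ.get("LD_LIBRARY_PATH", "")):
    status = "ok" if exists else "missing"
    print(f"{status:8}{path}")
```

Running this from the same shell (or the same job script) that starts the failing program shows whether the exports reached that process.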

TensorFlow does appear to recognize both GPUs:

>> python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

I also have the right version of jaxlib installed:

>> pip freeze | grep jaxlib
jaxlib==0.3.0+cuda11.cudnn82
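
The part after the `+` in that version string records the CUDA and cuDNN builds the wheel targets. As a sketch, it can be split apart and compared by eye with the `nvcc` and cuDNN output above; the helper name `parse_jaxlib_build` is an assumption, and the tag format is inferred only from the one string shown here.

```python
import re

def parse_jaxlib_build(version):
    """Split a jaxlib version like '0.3.0+cuda11.cudnn82' into its parts."""
    base, _, local = version.partition("+")
    cuda = re.search(r"cuda(\d+)", local)
    cudnn = re.search(r"cudnn(\d+)", local)
    return {
        "jaxlib": base,
        # None means the version string carries no CUDA/cuDNN tag at all.
        "cuda": cuda.group(1) if cuda else None,
        "cudnn": cudnn.group(1) if cudnn else None,
    }

print(parse_jaxlib_build("0.3.0+cuda11.cudnn82"))
# {'jaxlib': '0.3.0', 'cuda': '11', 'cudnn': '82'}
```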

I am not sure what could be the issue here. Any help would be greatly appreciated! Thank you.

akhilnadigatla avatar Mar 11 '22 22:03 akhilnadigatla

Hi! I am also having this exact same problem, any progress on solutions to this?

blakelash avatar Mar 15 '22 02:03 blakelash

Unfortunately, no. :(

akhilnadigatla avatar Mar 15 '22 03:03 akhilnadigatla

It seems there is a problem with the way TensorFlow detects your GPU. Which version of TensorFlow do you have, and do you get the same problem when using only TensorFlow, without jax?

gabrielraya avatar Mar 30 '22 10:03 gabrielraya

I have TensorFlow version 2.8.0, and I do not have any problems running workloads on TensorFlow (which does seem to recognize my GPUs).

Running the following returns Num GPUs Available: 2 as expected:

import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
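
The equivalent check on the jax side would be roughly the following sketch. It uses `jax.default_backend()` and `jax.devices()`; the wrapper name `describe_jax_devices` is an assumption, and it returns `None` when jax cannot be imported so the snippet still runs in an environment without it.

```python
def describe_jax_devices():
    """Return (default_backend, [device strings]), or None if jax is not importable."""
    try:
        import jax
    except ImportError:
        return None
    # jax.devices() lists the devices of the default backend; after a CPU
    # fallback it contains only CPU devices.
    return jax.default_backend(), [str(d) for d in jax.devices()]

print(describe_jax_devices())
```

If the backend printed here is `cpu` while TensorFlow reports two GPUs in the same environment, the difference lies in how jaxlib initializes CUDA rather than in the driver being absent.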

akhilnadigatla avatar Mar 31 '22 15:03 akhilnadigatla

Was this issue resolved? @akhilnadigatla

sudhakarsingh27 avatar Aug 08 '22 20:08 sudhakarsingh27

Hi, I'm running into the same issue. Are there any solutions to resolve this?

aasthajh avatar Sep 03 '22 03:09 aasthajh

I'm closing this issue because we don't have a way to reproduce it and it relates to an old jax version. Please reopen it (or file a new issue) if you can still reproduce the problem with an up-to-date jax.

hawkinsp avatar Jun 21 '23 14:06 hawkinsp

I still get a similar error when my machine (Ubuntu 22 + CUDA 12.0 + jax 0.4.16) wakes up from sleep.

I0000 00:00:1695473905.493231    7204 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-23 21:58:25.538852: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_UNKNOWN: unknown error
I0923 21:58:25.539228 139973430083584 xla_bridge.py:513] Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA

klknn avatar Sep 23 '23 13:09 klknn

@hawkinsp Hi Peter, we are trying to run fit_all.sh from https://github.com/martiningram/mcmc_runtime_comparison?tab=readme-ov-file on an Ubuntu 20.04 GPU instance that has a single GPU device. The execution throws the same error as in the original description:

PyMC JAX CPU vectorized
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Compiling...
2024-01-07 20:42:07.355217: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

If I just query jax for GPUs, it lists one: gpu[0]. Do you need a new issue created in order to lend a helping hand? Thank you.
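
When jax sees the GPU interactively but not from inside a script, one thing worth comparing is the environment each process runs with. The sketch below prints a few variables that influence device visibility. The list is an assumption chosen for illustration and is not taken from fit_all.sh; `device_env_snapshot` is a hypothetical helper name.

```python
import os

# Variables that can change which devices CUDA or jax will use.
# This selection is illustrative, not exhaustive.
DEVICE_ENV_VARS = ("CUDA_VISIBLE_DEVICES", "JAX_PLATFORMS", "JAX_PLATFORM_NAME")

def device_env_snapshot(environ=None):
    """Return the current value (or None if unset) of each variable of interest."""
    environ = os.environ if environ is None else environ
    return {name: environ.get(name) for name in DEVICE_ENV_VARS}

for name, value in device_env_snapshot().items():
    print(f"{name}={'<unset>' if value is None else repr(value)}")
```

Running it both in the interactive session and at the top of the failing script makes any difference visible. An empty `CUDA_VISIBLE_DEVICES`, for example, hides every GPU from CUDA.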

AndreV84 avatar Jan 07 '24 20:01 AndreV84