deepmind-research

Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

Open · caihongch opened this issue 2 years ago · 3 comments

Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

How can I solve this problem?

caihongch, Apr 01 '23 15:04

From: https://pypi.org/project/jax/

pip install --upgrade pip

CUDA 12 installation

  • Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

CUDA 11 installation

  • Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
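Which of the two commands applies depends on the CUDA version your NVIDIA driver supports, which nvidia-smi prints in its header as "CUDA Version: X.Y". Below is a minimal Python sketch, not part of JAX's own tooling: the helper names pick_jax_cuda_extra and pip_command are hypothetical, and it assumes the driver's reported CUDA major version is the only deciding factor between the cuda12_pip and cuda11_pip extras (other requirements, such as a minimum driver version for each wheel, are not checked).

```python
import re
import subprocess

# Package index used by the pip commands above.
JAX_RELEASES = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def pick_jax_cuda_extra(nvidia_smi_text: str) -> str:
    """Hypothetical helper: choose a jax extra from nvidia-smi output.

    nvidia-smi reports the highest CUDA version the installed driver
    supports as "CUDA Version: X.Y" in its header line.
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)
    if match is None:
        raise ValueError("no 'CUDA Version' field found in nvidia-smi output")
    major = int(match.group(1))
    # Assumption: a driver that supports CUDA >= 12 can run the CUDA 12 wheels.
    if major >= 12:
        return "cuda12_pip"
    if major == 11:
        return "cuda11_pip"
    raise ValueError(f"driver only supports CUDA {major}.x; these wheels need 11 or 12")


def pip_command(extra: str) -> str:
    """Build the install command in the same form as the commands above."""
    return f'pip install --upgrade "jax[{extra}]" -f {JAX_RELEASES}'


if __name__ == "__main__":
    # nvidia-smi ships with the NVIDIA driver, so this needs a driver installed.
    output = subprocess.run(
        ["nvidia-smi"], capture_output=True, text=True, check=True
    ).stdout
    print(pip_command(pick_jax_cuda_extra(output)))
```

For example, a header reading "CUDA Version: 12.0" maps to the cuda12_pip command, and "CUDA Version: 11.8" maps to cuda11_pip.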

Then check that it works:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
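jax.lib.xla_bridge is an internal module and appears to be deprecated in newer JAX releases; the public functions jax.devices() and jax.default_backend() can do the same check. A minimal sketch, assuming a JAX release recent enough to provide jax.default_backend(); the function name gpu_visible_to_jax is hypothetical:

```python
import importlib.util


def gpu_visible_to_jax() -> bool:
    """Hypothetical helper: True if JAX is installed and sees at least one GPU."""
    if importlib.util.find_spec("jax") is None:
        return False  # JAX is not installed in this environment
    import jax

    # On a working CUDA install the default backend is "gpu"; after the
    # "No GPU/TPU found" warning it falls back to "cpu".
    print("default backend:", jax.default_backend())
    # CUDA devices report their platform as "gpu".
    return any(d.platform == "gpu" for d in jax.devices())


if __name__ == "__main__":
    print("GPU visible:", gpu_visible_to_jax())
```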

sumowi, Jun 05 '23 20:06

Hi, I tried your approach and installed the JAX GPU version using pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, then ran the code below:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

It shows:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu
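The warning suggests rerunning with TF_CPP_MIN_LOG_LEVEL=0 to see why the CUDA backend failed to initialize. A minimal sketch of doing that from inside Python; setting the variable before jax is imported is a precaution rather than a documented requirement, and list_jax_devices is a hypothetical name. From a shell, the equivalent is to prefix the command, e.g. TF_CPP_MIN_LOG_LEVEL=0 python your_script.py.

```python
import importlib.util
import os

# Lower the C++ log threshold so backend initialization messages are printed.
# Set it before jax is imported, so it is in place when the XLA runtime loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


def list_jax_devices() -> list[str]:
    """Hypothetical helper: import JAX (if installed) and list its devices."""
    if importlib.util.find_spec("jax") is None:
        return []  # JAX is not installed in this environment
    import jax

    return [str(d) for d in jax.devices()]


if __name__ == "__main__":
    print(list_jax_devices())
```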

My GPU specification is as follows:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   49C    P8    24W / 450W |    376MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2020      G   /usr/lib/xorg/Xorg                140MiB |
|    0   N/A  N/A      2151      G   /usr/bin/gnome-shell               83MiB |
|    0   N/A  N/A      3362      G   ...0/usr/lib/firefox/firefox      150MiB |
+-----------------------------------------------------------------------------+

I am using Python 3.11.0. I also installed some additional dependencies:

numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.69
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
pyyaml>=5.3.1
scipy>=1.4.1
tensorboard>=2.4.0
tensorflow>=2.3.1
tensorflow-hub>=0.11.0

Why is JAX not detecting my GPU/TPU? I am running Ubuntu 22.04.
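One thing worth checking (an assumption, not a confirmed diagnosis for this case): installing extra requirements afterwards, in particular jaxlib>=0.1.69, could let pip swap the CUDA build of jaxlib for a CPU-only one. The mid-2023 CUDA wheels from jax_cuda_releases.html carried a local version tag such as +cuda12.cudnn89, so the installed version string gives a hint. Newer JAX releases ship CUDA support as separate plugin packages instead, so treat this only as a heuristic. The helper names jaxlib_build_kind and installed_jax_versions are hypothetical:

```python
from importlib import metadata


def jaxlib_build_kind(version: str) -> str:
    """Hypothetical heuristic: classify a jaxlib version string.

    CUDA wheels from the jax_cuda_releases index used a local version tag
    such as "0.4.13+cuda12.cudnn89"; a plain "0.4.13" has no CUDA tag.
    """
    _, _, local = version.partition("+")
    if local.startswith("cuda12"):
        return "cuda12"
    if local.startswith("cuda11"):
        return "cuda11"
    return "no CUDA tag"


def installed_jax_versions() -> dict[str, str]:
    """Report installed jax/jaxlib versions ("missing" if not installed)."""
    versions = {}
    for name in ("jax", "jaxlib"):
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "missing"
    return versions


if __name__ == "__main__":
    versions = installed_jax_versions()
    print(versions)
    if versions["jaxlib"] != "missing":
        print("jaxlib build:", jaxlib_build_kind(versions["jaxlib"]))
```

If this reports "no CUDA tag" right after a cuda12_pip install, rerunning that install command last, after the other requirements, is one way to test the hypothesis.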

Dharmendra04, Jul 15 '23 19:07

Thank you, this worked, though I first had to install the following:

  • CUDA Toolkit: 11.8
  • cuDNN: 8.9.23.28
  • TensorRT: the version for CUDA 11.8

Then: pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

gmacgmac, Jul 31 '23 14:07