deepmind-research
Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
How can I solve this problem?
From: https://pypi.org/project/jax/
pip install --upgrade pip
CUDA 12 installation
- Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CUDA 11 installation
- Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
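Choosing between the two commands comes down to which CUDA major version the NVIDIA driver supports. A rough sketch, assuming the input is the "CUDA Version" field printed by nvidia-smi (the highest CUDA release the driver supports) and that the newest matching extra is preferred. `choose_jax_extra` and `pip_command` are hypothetical helpers, and the check ignores minimum minor-version and driver-version requirements, so it is not a compatibility guarantee.

```python
import re


def choose_jax_extra(driver_cuda_version):
    """Map nvidia-smi's "CUDA Version" (e.g. "12.0") to a jax pip extra.

    Assumption: the driver can run any wheel whose CUDA major version is
    <= the version it reports, so the newest such extra is returned.
    """
    match = re.fullmatch(r"(\d+)\.(\d+)", driver_cuda_version.strip())
    if match is None:
        raise ValueError(f"unrecognized CUDA version: {driver_cuda_version!r}")
    major = int(match.group(1))
    if major >= 12:
        return "cuda12_pip"
    if major == 11:
        return "cuda11_pip"
    raise ValueError("driver too old for the cuda11_pip/cuda12_pip wheels")


def pip_command(extra):
    """Build the install command shown in the jax PyPI instructions."""
    return (
        f'pip install --upgrade "jax[{extra}]" '
        "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
    )


if __name__ == "__main__":
    print(pip_command(choose_jax_extra("12.0")))
```

For a driver reporting CUDA 12.0 this selects the cuda12_pip command; a driver reporting 11.x gets cuda11_pip.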
Then check whether it works:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
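A slightly fuller check prints the installed jax and jaxlib versions next to the backend JAX actually selected. An AttributeError like the GpuAllocatorConfig one is often attributed to a jaxlib build that does not match the installed jax, so seeing both versions side by side is a reasonable first step; treat that diagnosis as an assumption, not something confirmed in this thread. This is a minimal sketch assuming Python 3.8+ (for importlib.metadata); `installed_version` and `jax_report` are hypothetical helper names, while `jax.default_backend()` and `jax.devices()` are standard JAX functions.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_report():
    """Collect jax/jaxlib versions plus the backend and devices JAX selects."""
    report = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "backend": None,
        "devices": [],
        "error": None,
    }
    if importlib.util.find_spec("jax") is None:
        return report  # jax is not installed in this environment
    try:
        import jax  # imported lazily so the report works even without jax

        report["backend"] = jax.default_backend()  # e.g. "cpu" or "gpu"
        report["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # a broken jax/jaxlib pair can fail at import
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    for key, value in jax_report().items():
        print(f"{key}: {value}")
```

If "backend" comes back as "cpu" while a GPU is present, the two version lines are the first thing worth comparing against the release notes.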
Hi, I tried your approach and installed the JAX GPU version using pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, then ran the following code:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
It prints:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu
My GPU specification is as follows:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   49C    P8    24W / 450W |    376MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2020      G   /usr/lib/xorg/Xorg                140MiB |
|    0   N/A  N/A      2151      G   /usr/bin/gnome-shell               83MiB |
|    0   N/A  N/A      3362      G   ...0/usr/lib/firefox/firefox      150MiB |
+-----------------------------------------------------------------------------+
I am using Python 3.11.0. I also installed some additional dependencies:
numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.69
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
pyyaml>=5.3.1
scipy>=1.4.1
tensorboard>=2.4.0
tensorflow>=2.3.1
tensorflow-hub>=0.11.0
Why is JAX not detecting my GPU? I am running Ubuntu 22.04.
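One quick thing to rule out for this question is a driver that is too old for the CUDA 12 wheels. A minimal sketch, assuming nvidia-smi prints its usual header line with "Driver Version:" and "CUDA Version:" fields; `parse_smi_header` and `driver_supports` are hypothetical helpers, and comparing only the CUDA major version is a simplification.

```python
import re
import shutil
import subprocess

# Matches e.g. "Driver Version: 525.125.06   CUDA Version: 12.0"
_HEADER_RE = re.compile(
    r"Driver Version:\s*(?P<driver>[\d.]+)\s+CUDA Version:\s*(?P<cuda>[\d.]+)"
)


def parse_smi_header(text):
    """Return (driver_version, cuda_version) from nvidia-smi output, or None."""
    match = _HEADER_RE.search(text)
    if match is None:
        return None
    return match.group("driver"), match.group("cuda")


def driver_supports(cuda_version, wheel_major):
    """True if the driver's highest supported CUDA major >= the wheel's."""
    return int(cuda_version.split(".")[0]) >= wheel_major


if __name__ == "__main__":
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found: the NVIDIA driver may not be installed")
    else:
        out = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, check=False
        ).stdout
        parsed = parse_smi_header(out)
        print("driver, CUDA:", parsed)
        if parsed is not None:
            print("cuda12 wheels plausible:", driver_supports(parsed[1], 12))
```

With the header posted above (driver 525.125.06, CUDA 12.0) this check passes for cuda12_pip, so it does not by itself explain the CPU fallback; the warning's own suggestion, rerunning with TF_CPP_MIN_LOG_LEVEL=0, is the natural next step for more detail.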
Thank you, this worked,
though I first had to install all of the following:
CUDA Toolkit: I used 11.8
cuDNN: I used 8.9.23.28
TensorRT: I used the version for CUDA 11.8
Then I ran:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
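The components listed above (CUDA Toolkit, cuDNN, TensorRT) are shared libraries, so one way to confirm they are visible to the dynamic loader is Python's `ctypes.util.find_library`. A minimal sketch, assuming a Linux system and the usual library base names cudart, cudnn and nvinfer (an assumption about a typical install); `find_components` and `COMPONENT_LIBS` are hypothetical names. The cuda11_pip and cuda12_pip extras install CUDA libraries from pip, so a "not found" here does not by itself mean JAX will fail.

```python
import ctypes.util

# Hypothetical mapping from the components listed above to the shared
# library base names they usually provide (libcudart.so, libcudnn.so,
# libnvinfer.so). These names are assumptions about a typical Linux setup.
COMPONENT_LIBS = {
    "CUDA runtime": "cudart",
    "cuDNN": "cudnn",
    "TensorRT": "nvinfer",
}


def find_components(libs=COMPONENT_LIBS):
    """Map each component to the library file the loader finds, or None."""
    return {name: ctypes.util.find_library(lib) for name, lib in libs.items()}


if __name__ == "__main__":
    for name, path in find_components().items():
        print(f"{name:12s} {path or 'not found on the library search path'}")
```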