jax
Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
Description
bash-4.2$ python3 -c "import jax; jax.numpy.zeros(0)"
2024-02-12 20:36:44.949913: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2024-02-12 20:36:44.950054: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 42069065728 bytes free, 42505273344 bytes total.
2024-02-12 20:36:44.950107: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:488] Possibly insufficient driver version: 470.182.3
2024-02-12 20:36:44.950492: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2024-02-12 20:36:44.950523: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 42069065728 bytes free, 42505273344 bytes total.
2024-02-12 20:36:44.950551: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:488] Possibly insufficient driver version: 470.182.3
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2317, in zeros
    return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1226, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
What jax/jaxlib version are you using?
jax: 0.4.24, jaxlib: 0.4.24
Which accelerator(s) are you using?
GPU
Additional system info?
numpy: 1.26.3
python: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
uname_result(system='Linux', node='marvel-1-13', release='3.10.0-957.1.3.el7.x86_64', version='#1 SMP Thu Nov 29 14:49:43 UTC 2018', machine='x86_64')
NVIDIA GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   28C    P0    53W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
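The XLA log prints the driver version as 470.182.3, while nvidia-smi reports 470.182.03. These appear to be the same version, with the log printing each component as a plain integer. The sketch below (not part of the original report) normalizes both strings to integer tuples so they compare correctly. The helper names parse_driver_version and meets_minimum are hypothetical, and the minimum version passed to meets_minimum is an assumption you would take from the JAX/CUDA installation docs for your setup; no required minimum is stated in this thread.

```python
import re


def parse_driver_version(text: str) -> tuple[int, ...]:
    """Find the first dotted version in `text` (e.g. '470.182.03') and
    return it as a tuple of ints, so '03' and '3' compare equal."""
    match = re.search(r"\d+(?:\.\d+)+", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(driver: str, minimum: str) -> bool:
    """Hypothetical check: True if `driver` >= `minimum`.

    `minimum` is a caller-supplied assumption, not a documented requirement.
    Tuples compare element by element, so (470, 182, 3) > (470, 57, 2)."""
    return parse_driver_version(driver) >= parse_driver_version(minimum)


if __name__ == "__main__":
    # Strings copied from the XLA log and the nvidia-smi header above.
    log_line = "Possibly insufficient driver version: 470.182.3"
    smi_line = "| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03"
    print(parse_driver_version(log_line) == parse_driver_version(smi_line))
```

Comparing integer tuples rather than raw strings avoids treating "470.182.03" and "470.182.3" as different, and avoids lexicographic surprises such as "470.9" sorting after "470.182".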
Which CUDA version of JAX did you install?
@hawkinsp
python3 -m pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
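After installing with the command above, a quick check along these lines (a sketch, not from the thread) can show which backend and devices JAX detected and whether the one-line reproduction from the description now runs. The function name jax_gpu_smoke_test is hypothetical; it relies only on jax.__version__, jax.default_backend(), jax.devices(), and the array method block_until_ready().

```python
def jax_gpu_smoke_test() -> str:
    """Return 'ok' if a trivial JAX op runs, else a short failure message."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        return f"import failed: {err}"
    try:
        # Report what JAX picked up: version, backend name, visible devices.
        print("jax", jax.__version__, "backend:", jax.default_backend())
        print("devices:", jax.devices())
        # Same operation as the failing one-liner in the report; block until
        # the computation has actually executed on the device.
        jnp.zeros(0).block_until_ready()
    except Exception as err:  # e.g. the XlaRuntimeError shown in the traceback
        return f"runtime failed: {err}"
    return "ok"


if __name__ == "__main__":
    print(jax_gpu_smoke_test())
```

If cuDNN initialization still fails, the returned string should start with "runtime failed:" and carry the error text, which is easier to paste into an issue than the full traceback.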
Have you solved this issue? I'm still facing this problem.
@pengzhi1998 I did a clean reinstall of the CUDA driver and toolkit at some point. That seems to have solved the issue.