Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED

Open carlosgmartin opened this issue 1 year ago • 2 comments

Description

bash-4.2$ python3 -c "import jax; jax.numpy.zeros(0)"
2024-02-12 20:36:44.949913: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2024-02-12 20:36:44.950054: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 42069065728 bytes free, 42505273344 bytes total.
2024-02-12 20:36:44.950107: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:488] Possibly insufficient driver version: 470.182.3
2024-02-12 20:36:44.950492: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2024-02-12 20:36:44.950523: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 42069065728 bytes free, 42505273344 bytes total.
2024-02-12 20:36:44.950551: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:488] Possibly insufficient driver version: 470.182.3
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2317, in zeros
    return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1226, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/marvel/home/cgmartin/miniforge3/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

0.4.24 0.4.24

Which accelerator(s) are you using?

GPU

Additional system info?

1.26.3
3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
uname_result(system='Linux', node='marvel-1-13', release='3.10.0-957.1.3.el7.x86_64', version='#1 SMP Thu Nov 29 14:49:43 UTC 2018', machine='x86_64')

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   28C    P0    53W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

carlosgmartin, Feb 13 '24 01:02

Which CUDA version of JAX did you install?

hawkinsp, Feb 13 '24 02:02

@hawkinsp

python3 -m pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

carlosgmartin, Feb 13 '24 03:02

Have you solved this issue? I'm still facing this problem.

pengzhi1998, Apr 10 '24 00:04

@pengzhi1998 I did a clean reinstall of the CUDA driver and toolkit at some point. That seems to have solved the issue.

carlosgmartin, Apr 11 '24 04:04
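
Editorial sketch (not part of the original thread): the log above flags "Possibly insufficient driver version: 470.182.3", and the reported fix was a clean reinstall of the CUDA driver and toolkit. For anyone landing here with the same error, the snippet below is one way to read the installed driver version and compare it against a minimum. It is a minimal sketch under stated assumptions: the function names are hypothetical helpers, and `PLACEHOLDER_MINIMUM` is an illustrative value only, not a requirement stated in this thread; look up the real minimum for your CUDA toolkit in NVIDIA's release notes. It relies on `nvidia-smi` being on `PATH`.

```python
import shutil
import subprocess

# Placeholder only (assumption): replace with the minimum driver version that
# NVIDIA documents for the CUDA toolkit you actually installed.
PLACEHOLDER_MINIMUM = (525, 60, 13)


def parse_driver_version(text):
    """Turn a dotted version string such as '470.182.03' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def installed_driver_version():
    """Ask nvidia-smi for the driver version; return None if it cannot be read."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    lines = result.stdout.strip().splitlines()
    if result.returncode != 0 or not lines:
        return None
    try:
        # With several GPUs nvidia-smi prints one line per device; use the first.
        return parse_driver_version(lines[0])
    except ValueError:
        return None


def driver_is_sufficient(installed, minimum):
    """Compare two version tuples component by component."""
    return installed >= minimum


if __name__ == "__main__":
    version = installed_driver_version()
    if version is None:
        print("Could not read the driver version from nvidia-smi.")
    else:
        ok = driver_is_sufficient(version, PLACEHOLDER_MINIMUM)
        print("driver:", version, "meets placeholder minimum:", ok)
```

Parsing into integers also explains why the XLA log shows `470.182.3` while `nvidia-smi` shows `470.182.03`: both parse to the same tuple, so they refer to the same driver.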