JAX error: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed
Description
Hey, I am working on code that uses the JAX library, and I keep running into this error no matter how I configure my environment:
2024-08-20 16:26:58.037892: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-08-20 16:26:58.037952: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 46637514752 bytes free, 47587131392 bytes total.
Traceback (most recent call last):
File "GPU_pairwise_pipline.py", line 260, in <module>
SSMD_res_with_indices = process_blocks(train_set_sick, train_set_healthy, block_size)
File "GPU_pairwise_pipline.py", line 172, in process_blocks
mean_block1_sick, var_block1_sick = cal_mean_and_var(block1_sick)
File "GPU_pairwise_pipline.py", line 17, in cal_mean_and_var
data_jax = jnp.array(data)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2035, in array
out = _array_copy(object) if copy else object
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4447, in _array_copy
return copy_p.bind(arr)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4486, in _copy_impl
return dispatch.apply_primitive(prim, *args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Memory does not appear to be the problem (the log above reports about 46.6 GB free out of 47.6 GB total), and I have set $LD_LIBRARY_PATH to point to the directory containing the cuDNN version I downloaded:
echo $LD_LIBRARY_PATH
/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
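One way to cross-check the two version sources in this report is sketched below: it parses the `CUDNN_MAJOR` / `CUDNN_MINOR` defines quoted above and compares them with the `cudnn89` suffix of the jaxlib version string. The helper names `parse_cudnn_header` and `cudnn_tag_from_jaxlib` are hypothetical, and reading `cudnn89` as "cuDNN 8.9" is an assumption about the wheel naming rather than a documented API.

```python
import re


def parse_cudnn_header(header_text):
    """Return (major, minor) from '#define CUDNN_MAJOR/MINOR' lines, or None."""
    found = {}
    for name in ("MAJOR", "MINOR"):
        match = re.search(r"#define\s+CUDNN_" + name + r"\s+(\d+)", header_text)
        if match is None:
            return None
        found[name] = int(match.group(1))
    return found["MAJOR"], found["MINOR"]


def cudnn_tag_from_jaxlib(jaxlib_version):
    """Return the digits after 'cudnn' in a jaxlib version string, or None."""
    match = re.search(r"cudnn(\d+)", jaxlib_version)
    return match.group(1) if match else None


# Values copied from the report above.
header = "#define CUDNN_MAJOR 8\n#define CUDNN_MINOR 9\n"
major, minor = parse_cudnn_header(header)
tag = cudnn_tag_from_jaxlib("0.4.13+cuda12.cudnn89")

# Assumption: the tag is the major and minor digits concatenated.
versions_agree = tag == f"{major}{minor}"
print(major, minor, tag, versions_agree)
```

For the values in this report the header and the wheel tag agree under that assumption. This says nothing about which cuDNN shared library is actually loaded at run time.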
I would appreciate any help in understanding why I keep getting this error.
Thank you!
System info (python version, jaxlib version, accelerator, etc.)
nvidia-smi
Tue Aug 20 16:52:27 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA L40 On | 00000000:01:00.0 Off | 0 |
| N/A 39C P0 79W / 300W | 894MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA L40 On | 00000000:02:00.0 Off | 0 |
| N/A 29C P8 34W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA L40 On | 00000000:61:00.0 Off | 0 |
| N/A 29C P8 34W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA L40 On | 00000000:62:00.0 Off | 0 |
| N/A 30C P8 35W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 999897 C ...physics/rtk243/anaconda3/bin/python 868MiB |
| 1 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
+---------------------------------------------------------------------------------------+
pip show jax jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:
It looks like you're running a really old version of JAX (0.4.13; the current version is 0.4.31). You'll probably need to update your Python version because 3.8 is no longer supported, and then install the most recent version of JAX (pip install -U "jax[cuda12]") to see if that does the trick!
@dfm Thank you for your response!
I tried to upgrade to a newer version, but this is all I get after running 'pip install -U "jax[cuda12]"':
pip show jax jaxlib
WARNING: Package(s) not found: jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
For some reason, it doesn't upgrade.
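The `pip show` output above reports `jaxlib` as not found while `jax` is still at 0.4.13. A minimal sketch for checking the same thing from inside the interpreter that actually runs the script, using only the standard library's `importlib.metadata`; the helper names `installed_version` and `report` are hypothetical.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


def report(packages=("jax", "jaxlib")):
    """Map each package name to its installed version (None when missing)."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    for name, version in report().items():
        print(name, version if version is not None else "NOT INSTALLED")
```

Running this with the environment's own `python` avoids confusion when `pip` and `python` on the PATH belong to different environments.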
Did you upgrade your Python version? Python 3.8 is no longer supported!
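A small sketch of the check being asked about here. pip honors each release's declared Python requirement, so on an unsupported interpreter it keeps selecting an older release, which would be consistent with the version staying at 0.4.13 above. The minimum of 3.10 used below is an assumption for illustration; the thread only states that 3.8 is no longer supported, so confirm the real requirement in the JAX installation docs. The name `python_is_new_enough` is hypothetical.

```python
import sys

# Assumption: 3.10 stands in for the minimum Python of recent JAX releases.
ASSUMED_MINIMUM = (3, 10)


def python_is_new_enough(version_info=None, minimum=ASSUMED_MINIMUM):
    """Return True when the (major, minor) version is at least `minimum`."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    current = "{}.{}".format(*sys.version_info[:2])
    if python_is_new_enough():
        print("Python", current, "meets the assumed minimum")
    else:
        print("Python", current, "is too old; pip will likely keep an older jax")
```

If the check fails, creating a fresh environment with a newer Python and then running the install command from the reply above is the usual route.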
Given the 🎉 reaction, I'm going to hope that you got this sorted out on your end @MikaBell and close this issue. Please feel free to comment or open a new issue if the problem persists!