jax library error jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed

Open MikaBell opened this issue 1 year ago • 3 comments

Description

Hey, I am working on code that uses the JAX library, and I keep running into this error no matter how I configure my environment:

2024-08-20 16:26:58.037892: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-08-20 16:26:58.037952: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 46637514752 bytes free, 47587131392 bytes total.
Traceback (most recent call last):
  File "GPU_pairwise_pipline.py", line 260, in <module>
    SSMD_res_with_indices = process_blocks(train_set_sick, train_set_healthy, block_size)
  File "GPU_pairwise_pipline.py", line 172, in process_blocks
    mean_block1_sick, var_block1_sick = cal_mean_and_var(block1_sick)
  File "GPU_pairwise_pipline.py", line 17, in cal_mean_and_var
    data_jax = jnp.array(data)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2035, in array
    out = _array_copy(object) if copy else object
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4447, in _array_copy
    return copy_p.bind(arr)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4486, in _copy_impl
    return dispatch.apply_primitive(prim, *args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
    compiled = _xla_callable_uncached(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
    return computation.compile().unsafe_call
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Memory is not the problem (the log above shows almost all GPU memory free), and I have set $LD_LIBRARY_PATH to point to the directory that contains the cuDNN version I downloaded:

echo $LD_LIBRARY_PATH
/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/


#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
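
For anyone checking the same setup, here is a minimal sketch (not part of the original report) that lists the cuDNN shared libraries visible on $LD_LIBRARY_PATH. The helper name find_cudnn_libraries is hypothetical, and the libcudnn*.so* filename pattern is an assumption about how cuDNN is usually packaged on Linux. It only shows which files would be candidates for loading; it does not prove that they match the cuDNN 8.9 build the installed jaxlib expects.

```python
import glob
import os


def find_cudnn_libraries(search_path=None):
    """Return cuDNN shared libraries found in the given search path.

    search_path is an os.pathsep-separated list of directories; when it is
    None, the value of LD_LIBRARY_PATH is used. The "libcudnn*.so*" pattern
    is an assumption about the usual Linux file naming.
    """
    if search_path is None:
        search_path = os.environ.get("LD_LIBRARY_PATH", "")
    found = []
    for directory in search_path.split(os.pathsep):
        if not directory:
            continue  # skip empty entries such as a trailing ":"
        pattern = os.path.join(directory, "libcudnn*.so*")
        found.extend(sorted(glob.glob(pattern)))
    return found


if __name__ == "__main__":
    libraries = find_cudnn_libraries()
    if libraries:
        for path in libraries:
            print(path)
    else:
        print("No libcudnn*.so* files found on LD_LIBRARY_PATH")
```

If more than one directory on the path contains cuDNN files, the earlier directory normally wins, so the order of the printed list is worth a look.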

I would appreciate help understanding why I keep getting this error.

Thank you!

System info (python version, jaxlib version, accelerator, etc.)

nvidia-smi
Tue Aug 20 16:52:27 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L40                     On  | 00000000:01:00.0 Off |                    0 |
| N/A   39C    P0              79W / 300W |    894MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA L40                     On  | 00000000:02:00.0 Off |                    0 |
| N/A   29C    P8              34W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA L40                     On  | 00000000:61:00.0 Off |                    0 |
| N/A   29C    P8              34W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA L40                     On  | 00000000:62:00.0 Off |                    0 |
| N/A   30C    P8              35W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A    999897      C   ...physics/rtk243/anaconda3/bin/python      868MiB |
|    1   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    2   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    3   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+
pip show jax jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:

MikaBell avatar Aug 20 '24 14:08 MikaBell

It looks like you're running a really old version of JAX (0.4.13; the current version is 0.4.31). You'll probably need to update your Python version because 3.8 is no longer supported, and then install the most recent version of JAX (pip install -U "jax[cuda12]") to see if that does the trick!
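
Once the upgrade is in place, a quick smoke test along these lines may help confirm that the GPU backend initializes. This is only a sketch: the function name jax_smoke_test is hypothetical, and a tiny array operation is assumed to be enough to trigger the same backend initialization that failed in the traceback above.

```python
def jax_smoke_test():
    """Run a tiny JAX computation and report which backend handled it.

    Returns None when JAX cannot be imported, so the check can also run in
    an environment where the installation itself is broken.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    x = jnp.arange(4.0)  # forces backend initialization and a small compile
    return {
        "backend": jax.default_backend(),  # e.g. "gpu" or "cpu"
        "devices": [str(device) for device in jax.devices()],
        "sum": float(x.sum()),
    }


if __name__ == "__main__":
    result = jax_smoke_test()
    if result is None:
        print("JAX is not importable in this environment")
    else:
        print(result)
```

If the reported backend is "cpu" on a GPU machine, JAX most likely fell back because the CUDA-enabled jaxlib was not installed or could not load its libraries.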

dfm avatar Aug 20 '24 15:08 dfm

@dfm Thank you for your response!

I tried to upgrade to a newer version, but this is all I get after running 'pip install -U "jax[cuda12]"':

pip show jax jaxlib
WARNING: Package(s) not found: jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

For some reason, it doesn't upgrade.

MikaBell avatar Aug 20 '24 15:08 MikaBell

Did you upgrade your Python version? Python 3.8 is no longer supported!
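
A small check like the following can show at a glance whether the interpreter is new enough and whether jax and jaxlib are both installed. It is a sketch only: the helper names are hypothetical, and the default minimum of Python 3.10 is an assumption about what recent JAX releases require, so adjust it to whatever the JAX installation docs list for the version you want.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(min_python=(3, 10)):
    """Summarize the Python and JAX versions of the current environment.

    min_python is an assumed minimum for recent JAX releases, not a value
    taken from this thread.
    """
    return {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "python_ok": sys.version_info[:2] >= tuple(min_python),
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

When python_ok is False, pip can only pick from releases that still declare support for the old interpreter, which would explain why the version stays at 0.4.13 here. A jaxlib value of None matches the "Package(s) not found: jaxlib" warning above.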

dfm avatar Aug 20 '24 15:08 dfm

Given the 🎉 reaction, I'm going to hope that you got this sorted out on your end @MikaBell and close this issue. Please feel free to comment or open a new issue if the problem persists!

dfm avatar Sep 03 '24 10:09 dfm