cuDNN Initialization Error with Pre-existing CUDA Environment
Description
Our cluster node already has the Nvidia drivers and the CUDA toolkit installed (to maintain version compatibility with the underlying OS and the networking stack).
Installing via jax[cuda12_local] would therefore make sense. However, as mentioned in the installation guide, JAX requires cuDNN, which means additional packages need to be installed (assuming JAX and the additional packages are installed in the Docker container).
With that setup, it fails with the error below. So, is the recommended practice to always install with jax[cuda12] even though CUDA is already installed?
Note:
I am setting LD_LIBRARY_PATH to the location where cuDNN was installed by pip install nvidia-cudnn-cu12.
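Since LD_LIBRARY_PATH has to point at the directory that actually holds libcudnn, here is a minimal sketch for locating it. It assumes the nvidia-cudnn-cu12 wheel places its libraries under <site-packages>/nvidia/cudnn/lib, which is worth verifying on your system. find_cudnn_lib_dirs is a hypothetical helper name, not part of JAX or the wheel.

```python
import glob
import os
import sys


def find_cudnn_lib_dirs(roots=None):
    """Return directories under `roots` that contain a libcudnn shared library.

    Assumption: the pip wheel uses the layout <root>/nvidia/cudnn/lib/.
    By default, every non-empty entry of sys.path is searched.
    """
    if roots is None:
        roots = [p for p in sys.path if p]
    found = []
    for root in roots:
        lib_dir = os.path.join(root, "nvidia", "cudnn", "lib")
        # Match versioned names such as libcudnn.so.9 as well as libcudnn.so
        has_lib = glob.glob(os.path.join(lib_dir, "libcudnn*.so*"))
        if has_lib and lib_dir not in found:
            found.append(lib_dir)
    return found


if __name__ == "__main__":
    for lib_dir in find_cudnn_lib_dirs():
        print(lib_dir)
```

The printed directory is the one to put on LD_LIBRARY_PATH. The variable generally needs to be exported before the Python process starts; changing os.environ afterwards does not usually affect how the dynamic loader resolves libraries.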
Error
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/workspace/jax_pingpong.py", line 49, in <module>
    init_processes()
  File "/workspace/jax_pingpong.py", line 44, in init_processes
    run()
  File "/workspace/jax_pingpong.py", line 24, in run
    xs = jax.numpy.ones(jax.local_device_count())
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3883, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1302, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 566, in _convert_element_type
    return convert_element_type_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2559, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 429, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors
Note: jax.devices() gives the expected output
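As the note above suggests, device enumeration can succeed while the first real computation fails. The sketch below separates those stages so a report can say which one breaks. check_jax_stages is a hypothetical helper; the computation mirrors the failing line in the traceback, and the broad exception handling is a simplification for illustration.

```python
def check_jax_stages():
    """Return the first stage that fails: 'import', 'devices', 'compute', or 'ok'."""
    try:
        import jax
        import jax.numpy as jnp
    except Exception:
        # JAX is missing or could not be imported in this environment
        return "import"
    try:
        jax.devices()
    except Exception:
        # Backend discovery itself failed
        return "devices"
    try:
        # Same call as in the traceback; block_until_ready forces execution
        jnp.ones(jax.local_device_count()).block_until_ready()
    except Exception:
        return "compute"
    return "ok"


if __name__ == "__main__":
    print(check_jax_stages())
```

In the situation described in this issue, the expected result would be "compute" before the missing packages are installed.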
System info (python version, jaxlib version, accelerator, etc.)
- Accelerator: GPU
- CUDA-12
- cuDNN > 9.0
If you already have CUDA installed and want to use it, but cuDNN (or some other libraries) is missing, I think you should install jax[cuda12_local] and also install the missing packages.
You could do that via pip to keep it simple. For cudnn: nvidia-cudnn-cu12
Here are other wheels that you might need:
"nvidia-cublas-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-cuda-runtime-cu12",
"nvidia-cufft-cu12",
"nvidia-cusolver-cu12",
"nvidia-cusparse-cu12",
@nouiz, are the above packages not installed by the CUDA toolkit or during the cuDNN installation? Is there a document that mentions the list of required packages (either through pip or otherwise) for the end-to-end flow to work?
You don't have a normal setup, and there is no documentation for it. Some of the packages above are provided by the CUDA SDK, but cuDNN and NCCL are not. Can you try installing only the cuDNN and NCCL packages? If that doesn't work, report the error.
@nouiz Installing the packages that you mentioned in the previous comment resolved the issue:
pip install -U "jax[cuda12_local]"
pip install 'nvidia-cudnn-cu12>=9.1.0,<10.0.0'
pip install "nvidia-cublas-cu12"
pip install "nvidia-cuda-nvcc-cu12"
pip install "nvidia-cuda-runtime-cu12"
pip install "nvidia-cufft-cu12"
pip install "nvidia-cusolver-cu12"
pip install "nvidia-cusparse-cu12"
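To see which of these wheels are present in a given environment (for example, when trimming the list back down), a small check with the standard library can help. This is a sketch: installed_versions and NVIDIA_WHEELS are hypothetical names, and the package list is simply the one from the commands above.

```python
from importlib import metadata

# Wheel names taken from the pip commands in this thread
NVIDIA_WHEELS = [
    "nvidia-cudnn-cu12",
    "nvidia-cublas-cu12",
    "nvidia-cuda-nvcc-cu12",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cufft-cu12",
    "nvidia-cusolver-cu12",
    "nvidia-cusparse-cu12",
]


def installed_versions(names, lookup=metadata.version):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = lookup(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(NVIDIA_WHEELS).items():
        print(f"{name}: {version or 'not installed'}")
```

The lookup parameter exists only so the function can be exercised without the wheels installed; in normal use the default is enough.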
You probably installed more packages than necessary, but I don't think that is a problem. Should we close this issue?