cuDNN Initialization Error with Pre-existing CUDA Environment
Description
Our cluster node already has NVIDIA drivers and the CUDA toolkit installed (to maintain version compatibility with the underlying OS and the networking stack).
Installing via jax[cuda12_local] would therefore make sense. However, as mentioned in the installation guide, JAX also requires cuDNN, which means additional packages need to be installed (assuming JAX and the additional packages are installed in the Docker container).
With this setup, JAX fails with the error shown in the Error section. So, is the recommended practice to always install with jax[cuda12], even when CUDA is already installed?
Note: I am setting LD_LIBRARY_PATH to the location where cuDNN is installed by pip install nvidia-cudnn-cu12.
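For anyone reproducing this, here is a minimal sketch of how one might check that LD_LIBRARY_PATH really points at the pip-installed cuDNN. It assumes the nvidia-cudnn-cu12 wheel unpacks into an importable nvidia.cudnn package with a lib/ subdirectory; that layout is an assumption about the wheel, not something confirmed in this issue, and the helper names are hypothetical.

```python
import importlib.util
import os
from pathlib import Path
from typing import Optional


def find_pip_cudnn_lib_dir() -> Optional[Path]:
    """Return the lib/ directory of a pip-installed cuDNN, or None if not found.

    Assumption: the wheel installs as the package `nvidia.cudnn` and keeps
    its shared libraries under `<package>/lib`.
    """
    try:
        spec = importlib.util.find_spec("nvidia.cudnn")
    except (ImportError, ValueError):
        # Raised when the parent `nvidia` package is missing or malformed.
        return None
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        lib_dir = Path(location) / "lib"
        if lib_dir.is_dir():
            return lib_dir
    return None


def on_ld_library_path(directory: Path, ld_library_path: str) -> bool:
    """Check whether `directory` is one of the entries in an LD_LIBRARY_PATH string."""
    entries = [Path(p) for p in ld_library_path.split(os.pathsep) if p]
    return any(entry == directory for entry in entries)


if __name__ == "__main__":
    lib_dir = find_pip_cudnn_lib_dir()
    if lib_dir is None:
        print("No pip-installed cuDNN found under nvidia.cudnn")
    else:
        current = os.environ.get("LD_LIBRARY_PATH", "")
        print(f"cuDNN lib dir: {lib_dir}")
        print(f"On LD_LIBRARY_PATH: {on_ld_library_path(lib_dir, current)}")
```

Note that LD_LIBRARY_PATH is read by the dynamic loader when the process starts, so it has to be exported before launching Python; changing os.environ inside an already running interpreter does not affect library lookup.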
Error
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/workspace/jax_pingpong.py", line 49, in <module>
    init_processes()
  File "/workspace/jax_pingpong.py", line 44, in init_processes
    run()
  File "/workspace/jax_pingpong.py", line 24, in run
    xs = jax.numpy.ones(jax.local_device_count())
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3883, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1302, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 566, in _convert_element_type
    return convert_element_type_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2559, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 429, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors
Note: jax.devices() gives the expected output
System info (python version, jaxlib version, accelerator, etc.)
- Accelerator: GPU
- CUDA-12
- cuDNN > 9.0
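To make the cuDNN entry above more precise, a small sketch like the following could report which cuDNN build the dynamic loader actually resolves inside the container. It assumes the library's soname is libcudnn.so.9 (my assumption for a cuDNN 9.x install) and calls cudnnGetVersion from the cuDNN C API; the helper name is hypothetical.

```python
import ctypes
from typing import Optional


def loaded_cudnn_version(soname: str = "libcudnn.so.9") -> Optional[int]:
    """Return the integer reported by cudnnGetVersion(), or None if cuDNN cannot be loaded.

    Assumption: `soname` is the shared-library name for the installed cuDNN.
    The lookup goes through the normal loader search path, which includes
    LD_LIBRARY_PATH as it was set when the process started.
    """
    try:
        lib = ctypes.CDLL(soname)
        get_version = lib.cudnnGetVersion
    except (OSError, AttributeError):
        # OSError: library not found or failed to load.
        # AttributeError: library loaded but does not export the symbol.
        return None
    get_version.restype = ctypes.c_size_t
    get_version.argtypes = []
    return int(get_version())


if __name__ == "__main__":
    version = loaded_cudnn_version()
    if version is None:
        print("cuDNN could not be loaded from the current library search path")
    else:
        print(f"cudnnGetVersion() reports: {version}")
```

If this prints that cuDNN could not be loaded, the LD_LIBRARY_PATH setting is probably not reaching the process, which would be worth ruling out before looking at the JAX side.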