
cuDNN Initialization Error in a Pre-existing CUDA Environment

Open • parambole opened this issue 6 months ago • 5 comments

Description

Our cluster node already has the NVIDIA drivers and the CUDA toolkit installed (to maintain version compatibility with the underlying OS and the networking stack).

Installing via jax[cuda12_local] would therefore make sense. However, as the installation guide notes, JAX also requires cuDNN, so additional packages have to be installed (assuming JAX and these packages are installed inside the Docker container).
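Because the cuDNN dependency here comes from pip rather than from the host toolkit, a first sanity check is to list which NVIDIA wheels are actually present in the container's Python environment. The sketch below is a hypothetical helper (installed_nvidia_wheels is not part of JAX) that only uses the standard importlib.metadata module and assumes the relevant wheels follow the "nvidia-*" naming used by nvidia-cudnn-cu12.

```python
from importlib import metadata


def installed_nvidia_wheels() -> dict:
    """Return {distribution name: version} for every installed 'nvidia-*' wheel.

    Hypothetical diagnostic helper; it only reads package metadata and does
    not load any CUDA or cuDNN library.
    """
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return found


if __name__ == "__main__":
    # Expect to see nvidia-cudnn-cu12 here if the pip install succeeded.
    for name, version in sorted(installed_nvidia_wheels().items()):
        print(f"{name}=={version}")
```

If nvidia-cudnn-cu12 does not show up, the LD_LIBRARY_PATH setting has nothing to point at in the first place.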

However, this setup fails with the error below. Is the recommended practice, then, to always install jax[cuda12] even when CUDA is already installed on the host?

Note: I installed cuDNN with pip install nvidia-cudnn-cu12 and set LD_LIBRARY_PATH to point to the directory where it was installed.
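One way to find the directory to put on LD_LIBRARY_PATH is to ask Python where the pip wheel landed. This sketch assumes the nvidia-cudnn-cu12 wheel installs its shared libraries under site-packages/nvidia/cudnn/lib; that layout is an assumption to verify locally, and pip_cudnn_lib_dir is a hypothetical helper, not a JAX API.

```python
import importlib.util
import os
from typing import Optional


def pip_cudnn_lib_dir() -> Optional[str]:
    """Return the lib/ directory of the pip-installed cuDNN wheel, or None.

    Assumes (unverified) the wheel layout site-packages/nvidia/cudnn/lib.
    """
    try:
        spec = importlib.util.find_spec("nvidia.cudnn")
    except ModuleNotFoundError:
        # The parent 'nvidia' package is not installed at all.
        return None
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        lib_dir = os.path.join(location, "lib")
        if os.path.isdir(lib_dir):
            return lib_dir
    return None


if __name__ == "__main__":
    lib_dir = pip_cudnn_lib_dir()
    if lib_dir is None:
        print("nvidia-cudnn-cu12 does not appear to be installed")
    else:
        # Editing os.environ here would not change the dynamic loader's search
        # path for this already-running process, so emit a shell line instead.
        print(f'export LD_LIBRARY_PATH="{lib_dir}:$LD_LIBRARY_PATH"')
```

The export has to happen before the Python process that imports jax starts, since the loader reads LD_LIBRARY_PATH at startup.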

Error

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/jax_pingpong.py", line 49, in <module>
    init_processes()
  File "/workspace/jax_pingpong.py", line 44, in init_processes
    run()
  File "/workspace/jax_pingpong.py", line 24, in run
    xs = jax.numpy.ones(jax.local_device_count())
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3883, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1302, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 566, in _convert_element_type
    return convert_element_type_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2559, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 429, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors
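As the traceback itself suggests, the internal JAX frames can be restored by setting JAX_TRACEBACK_FILTERING=off. A minimal sketch of a standalone reproducer, assuming the failing call from run() can be isolated like this; the environment variable is set before jax is imported so the flag's default picks it up.

```python
import os

# Set before importing jax, as the traceback message suggests.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # lets the sketch run where jax is not installed
    jax = None

if jax is not None:
    print(jax.devices())  # reported to work in this issue
    # The call that raised the XlaRuntimeError in run():
    xs = jnp.ones(jax.local_device_count())
    print(xs)
```

The more detailed messages the error refers to ("Look at the errors") are typically logged to stderr before the exception is raised, so it is worth capturing the full console output, not only the Python traceback.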

Note: jax.devices() returns the expected output.
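Since device discovery works but cuDNN initialization fails, it can help to check which cuDNN the dynamic loader actually resolves under the current LD_LIBRARY_PATH. One possible cause (not confirmed for this issue) is a version mismatch between the cuDNN that jaxlib was built against and the one that gets loaded. The sketch below uses ctypes and the real cuDNN function cudnnGetVersion; the soname libcudnn.so.9 and the cuDNN 9 encoding major*10000 + minor*100 + patch are assumptions stated in the comments.

```python
import ctypes
from typing import Optional, Tuple


def decode_cudnn_version(raw: int) -> Tuple[int, int, int]:
    """Split a cuDNN 9.x version integer (assumed: major*10000 + minor*100 + patch)."""
    return raw // 10000, (raw % 10000) // 100, raw % 100


def loaded_cudnn_version(soname: str = "libcudnn.so.9") -> Optional[Tuple[int, int, int]]:
    """dlopen cuDNN via the normal loader search path and ask it for its version.

    Returns None if the library cannot be found; the default soname is an
    assumption for cuDNN 9.
    """
    try:
        lib = ctypes.CDLL(soname)
    except OSError:
        return None
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    lib.cudnnGetVersion.argtypes = []
    return decode_cudnn_version(lib.cudnnGetVersion())


if __name__ == "__main__":
    print("cuDNN resolved by the loader:", loaded_cudnn_version())
```

Comparing this result with the cuDNN version required by the installed jaxlib (see the JAX installation guide) narrows down whether the pip wheel or a system copy is being picked up.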

System info (python version, jaxlib version, accelerator, etc.)

  1. Accelerator: GPU
  2. CUDA: 12
  3. cuDNN: > 9.0

parambole, Aug 16 '24 21:08