CuDNN Initialization Error with Pre-existing CUDA Environment Fails with

Open parambole opened this issue 1 year ago • 5 comments

Description

Our cluster node already has NVIDIA drivers and the CUDA toolkit installed (to maintain version compatibility with the underlying OS and the networking stack).

Installing via jax[cuda12_local] would therefore make sense. However, as mentioned in the installation guide, JAX requires cuDNN, so additional packages need to be installed (assuming JAX and the additional packages are installed in the Docker container).

However, this setup fails with the error shown below. Is the recommended practice to always install with jax[cuda12], even when CUDA is already installed?

Note: I am setting LD_LIBRARY_PATH to the location where cuDNN was installed by pip install nvidia-cudnn-cu12.
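For anyone reproducing this, here is a minimal sketch of how the cuDNN directory for LD_LIBRARY_PATH could be located. It assumes the nvidia-cudnn-cu12 wheel places its shared libraries under `<site-packages>/nvidia/cudnn/lib`; that layout is an assumption and may differ between wheel versions. The helper names `find_cudnn_lib_dirs` and `prepend_to_ld_library_path` are hypothetical and not part of JAX or any NVIDIA package.

```python
import os
import sysconfig
from pathlib import Path


def find_cudnn_lib_dirs(search_roots):
    """Return every <root>/nvidia/cudnn/lib directory that exists.

    The nvidia/cudnn/lib layout is an assumption about the pip wheel.
    """
    found = []
    for root in search_roots:
        candidate = Path(root) / "nvidia" / "cudnn" / "lib"
        if candidate.is_dir():
            found.append(str(candidate))
    return found


def prepend_to_ld_library_path(lib_dirs, current=""):
    """Build an LD_LIBRARY_PATH value with lib_dirs placed first."""
    parts = list(lib_dirs) + [p for p in current.split(os.pathsep) if p]
    return os.pathsep.join(parts)


if __name__ == "__main__":
    # Look in the site-packages directories of the running interpreter.
    paths = sysconfig.get_paths()
    roots = {paths["purelib"], paths["platlib"]}
    dirs = find_cudnn_lib_dirs(sorted(roots))
    # Print a value suitable for `export LD_LIBRARY_PATH=...`.
    print(prepend_to_ld_library_path(dirs, os.environ.get("LD_LIBRARY_PATH", "")))
```

The dynamic loader reads LD_LIBRARY_PATH when a process starts, so the printed value would need to be exported in the shell or the Dockerfile before launching Python rather than set from inside the script.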

Error

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/jax_pingpong.py", line 49, in <module>
    init_processes()
  File "/workspace/jax_pingpong.py", line 44, in init_processes
    run()
  File "/workspace/jax_pingpong.py", line 24, in run
    xs = jax.numpy.ones(jax.local_device_count())
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3883, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1302, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 566, in _convert_element_type
    return convert_element_type_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2559, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 429, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors

Note: jax.devices() gives the expected output

System info (python version, jaxlib version, accelerator, etc.)

  1. Accelerator: GPU
  2. CUDA-12
  3. cuDNN > 9.0

parambole avatar Aug 16 '24 21:08 parambole

If you already have CUDA installed and want to use it, but cuDNN (or some other libraries) are missing, I think you should install jax[cuda12_local] and also install the missing packages. You could do that via pip to keep it simple. For cuDNN, the package is nvidia-cudnn-cu12. Here are other wheels that you might need:

          "nvidia-cublas-cu12",
          "nvidia-cuda-nvcc-cu12",
          "nvidia-cuda-runtime-cu12",
          "nvidia-cufft-cu12",
          "nvidia-cusolver-cu12",
          "nvidia-cusparse-cu12",
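To see which of these wheels are already present in an environment, a small check along the following lines may help. It is only a sketch: it uses the standard-library `importlib.metadata` to look up each distribution name from the list above (plus nvidia-cudnn-cu12), and it reports only what pip has installed, not what the system CUDA toolkit provides. The names `CANDIDATE_WHEELS` and `installed_versions` are hypothetical.

```python
from importlib import metadata

# Distribution names taken from the list above, plus the cuDNN wheel.
CANDIDATE_WHEELS = [
    "nvidia-cudnn-cu12",
    "nvidia-cublas-cu12",
    "nvidia-cuda-nvcc-cu12",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cufft-cu12",
    "nvidia-cusolver-cu12",
    "nvidia-cusparse-cu12",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(CANDIDATE_WHEELS).items():
        print(f"{name}: {version or 'not installed'}")
```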

nouiz avatar Aug 17 '24 14:08 nouiz

@nouiz, are the above packages not installed by the CUDA toolkit or during the cuDNN installation? Is there a document that mentions the list of required packages (either through pip or otherwise) for the end-to-end flow to work?

parambole avatar Aug 20 '24 18:08 parambole

Your setup isn't a standard one, and there is no documentation for it. Some of the packages above are provided by the CUDA SDK, but cuDNN and NCCL are not. Can you try installing only the cuDNN and NCCL packages? If that doesn't work, please report the error.
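One way to check whether cuDNN and NCCL are visible to the dynamic loader before running JAX is to try loading them directly. The sketch below assumes the Linux sonames `libcudnn.so.9` (cuDNN 9) and `libnccl.so.2` (NCCL 2); those names are assumptions and should be adjusted to the installed versions. The helpers `can_load` and `probe` are hypothetical.

```python
import ctypes

# Assumed sonames for cuDNN 9 and NCCL 2 on Linux; adjust as needed.
ASSUMED_SONAMES = ["libcudnn.so.9", "libnccl.so.2"]


def can_load(soname):
    """Return True if the dynamic loader can find and load soname."""
    try:
        ctypes.CDLL(soname)
    except OSError:
        return False
    return True


def probe(sonames):
    """Map each soname to whether it could be loaded."""
    return {name: can_load(name) for name in sonames}


if __name__ == "__main__":
    for name, ok in probe(ASSUMED_SONAMES).items():
        print(f"{name}: {'found' if ok else 'NOT found'}")
```

A successful load only shows that the library is on the search path. It does not guarantee that cuDNN initialization will succeed, for example if the cuDNN build does not match the installed CUDA version.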

nouiz avatar Aug 21 '24 00:08 nouiz

@nouiz Installing the packages that you mentioned in the previous comment resolved the issue:

 pip install -U "jax[cuda12_local]"
 pip install "nvidia-cudnn-cu12>=9.1.0,<10.0.0"
 pip install "nvidia-cublas-cu12"
 pip install "nvidia-cuda-nvcc-cu12"
 pip install "nvidia-cuda-runtime-cu12"
 pip install "nvidia-cufft-cu12"
 pip install "nvidia-cusolver-cu12"
 pip install "nvidia-cusparse-cu12"

parambole avatar Aug 21 '24 17:08 parambole

You probably installed too many packages. I don't think this is an issue. Should we close this issue?

nouiz avatar Aug 21 '24 18:08 nouiz