
cuSolver internal error on freshly installed cuda11.1 from conda-forge

Open PgLoLo opened this issue 3 years ago • 6 comments

In a freshly installed Python 3.9 environment with cuda11.1 and cudnn, any call to jax.numpy.linalg.qr produces an error: RuntimeError: jaxlib/cusolver.cc:52: operation cusolverDnCreate(&handle) failed: cuSolver internal error.

Installation:

conda create -n test -c conda-forge python=3.9 cudatoolkit=11.1 cudnn
conda activate test
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  # Note: wheels only available on linux.
python
>>> import jax
>>> jax.numpy.linalg.qr(jax.numpy.ones([3, 3]))
2021-09-25 11:10:05.278742: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 468, in qr
    q, r = lax_linalg.qr(a, full_matrices)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 197, in qr
    q, r = qr_p.bind(x, full_matrices=full_matrices)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/core.py", line 267, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/core.py", line 612, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 1092, in qr_impl
    q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 275, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 195, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 188, in cached
    return f(*args, **kwargs)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 317, in xla_primitive_callable
    built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 195, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 188, in cached
    return f(*args, **kwargs)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 357, in primitive_computation
    ans = rule(c, *xla_args, **params)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 1144, in _qr_cpu_gpu_translation_rule
    r, tau, info_geqrf = geqrf_impl(c, operand)
  File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jaxlib/cusolver.py", line 200, in geqrf
    lwork, opaque = cusolver_kernels.build_geqrf_descriptor(
RuntimeError: jaxlib/cusolver.cc:52: operation cusolverDnCreate(&handle) failed: cuSolver internal error

I am fairly sure that QR decomposition is not the only operation that fails in this setup. The same steps with cudatoolkit=10.2 (also from conda) work fine, and so does another machine where cuda11.1 was not installed from conda.
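The dso_loader warning in the log above says that libcusolver.so.11 could not be opened with the LD_LIBRARY_PATH in effect. A minimal sketch for checking this outside of JAX is shown below. It assumes that a ctypes lookup follows the same dynamic-loader search path that jaxlib relies on; the helper name can_load is hypothetical and not part of JAX.

```python
import ctypes
import os


def can_load(libname: str) -> bool:
    """Return True if the dynamic loader can open `libname`.

    Hypothetical helper for illustration; it only checks that the
    shared object can be found and opened, not that it works.
    """
    try:
        ctypes.CDLL(libname)
        return True
    except OSError:
        return False


# Show the search path the process was started with, then probe the
# library named in the warning.
print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "<unset>"))
print("libcusolver.so.11 loadable:", can_load("libcusolver.so.11"))
```

If the probe reports that the library is not loadable while the file exists under the conda environment's lib directory, the search path is the likely culprit rather than the CUDA installation itself.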

PgLoLo avatar Sep 25 '21 11:09 PgLoLo

I am encountering the same issue with cudatoolkit 11.3.1, cudatoolkit-dev 11.3.1 and cudnn 8.2.1 from conda/conda-forge. For me it fails at the eigh function, but as @PgLoLo already mentioned, it is very likely not the only broken function.

n-gao avatar Oct 18 '21 10:10 n-gao

I tried both the libraries from conda/conda-forge and the libraries from nvidia via conda.

n-gao avatar Oct 18 '21 10:10 n-gao

I had the same issue and solved it by setting LD_LIBRARY_PATH to the lib directory of the conda environment:

export LD_LIBRARY_PATH=/path/to/miniconda3/envs/{your_env_name}/lib
python -c "import jax; jax.numpy.linalg.qr(jax.numpy.ones([3,3]))"
# works fine

k-khr avatar Oct 19 '21 15:10 k-khr

@k-khr that works for me, thanks!

n-gao avatar Oct 20 '21 07:10 n-gao

@PgLoLo was this resolved?

sudhakarsingh27 avatar Aug 12 '22 19:08 sudhakarsingh27

One has to set LD_LIBRARY_PATH manually; is there a better solution for this? I personally set it in my conda environment via conda env config vars set LD_LIBRARY_PATH=.... However, this has to be done for each new environment and is easily forgotten when creating one.
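One way to avoid repeating the setting by hand is a small launcher that derives the path from the active environment. The sketch below is not an official JAX or conda mechanism. It assumes that conda exports CONDA_PREFIX on activation and that the CUDA libraries live in the environment's lib directory, as in the workaround above; the function names are hypothetical. The variable is set for a child process because, on Linux with glibc, the loader reads LD_LIBRARY_PATH at process start, so changing it from inside an already running Python interpreter generally has no effect.

```python
import os
import subprocess


def env_with_conda_lib(environ):
    """Return a copy of `environ` with $CONDA_PREFIX/lib prepended to
    LD_LIBRARY_PATH. Existing entries are kept; without CONDA_PREFIX
    the copy is returned unchanged. Hypothetical helper."""
    env = dict(environ)
    prefix = env.get("CONDA_PREFIX")
    if not prefix:
        return env
    lib_dir = os.path.join(prefix, "lib")
    # Drop empty entries so we never emit a leading or trailing separator.
    parts = [p for p in env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
    if lib_dir not in parts:
        parts.insert(0, lib_dir)
    env["LD_LIBRARY_PATH"] = os.pathsep.join(parts)
    return env


def run_with_conda_lib(argv):
    """Run `argv` as a child process with the adjusted environment and
    return its exit code."""
    return subprocess.run(argv, env=env_with_conda_lib(os.environ)).returncode


# Example call (not executed here):
# run_with_conda_lib(["python", "-c",
#     "import jax; jax.numpy.linalg.qr(jax.numpy.ones([3, 3]))"])
```

Because the path is computed from CONDA_PREFIX, the same launcher works in any activated environment, although it still has to be invoked explicitly.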

n-gao avatar Sep 08 '22 08:09 n-gao