cuSolver internal error on freshly installed cuda11.1 from conda-forge
In a freshly installed Python 3.9 environment with cudatoolkit 11.1 and cudnn, any call to jax.numpy.linalg.qr produces an error: RuntimeError: jaxlib/cusolver.cc:52: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Installation:
conda create -n test -c conda-forge python=3.9 cudatoolkit=11.1 cudnn
conda activate test
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
python
>>> import jax
>>> jax.numpy.linalg.qr(jax.numpy.ones([3, 3]))
2021-09-25 11:10:05.278742: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 468, in qr
q, r = lax_linalg.qr(a, full_matrices)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 197, in qr
q, r = qr_p.bind(x, full_matrices=full_matrices)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/core.py", line 267, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/core.py", line 612, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 1092, in qr_impl
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 275, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 195, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 188, in cached
return f(*args, **kwargs)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 317, in xla_primitive_callable
built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 195, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/util.py", line 188, in cached
return f(*args, **kwargs)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/interpreters/xla.py", line 357, in primitive_computation
ans = rule(c, *xla_args, **params)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 1144, in _qr_cpu_gpu_translation_rule
r, tau, info_geqrf = geqrf_impl(c, operand)
File "/home/gnovikov/data/miniconda3/envs/test/lib/python3.9/site-packages/jaxlib/cusolver.py", line 200, in geqrf
lwork, opaque = cusolver_kernels.build_geqrf_descriptor(
RuntimeError: jaxlib/cusolver.cc:52: operation cusolverDnCreate(&handle) failed: cuSolver internal error
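The warning printed before the traceback says that libcusolver.so.11 could not be found on the loader's search path, and that LD_LIBRARY_PATH only contained /usr/local/nvidia/lib and /usr/local/nvidia/lib64. The following is a minimal diagnostic sketch, assuming Linux. It tries to dlopen the same library through ctypes and prints the LD_LIBRARY_PATH the process started with. The helper name try_load is hypothetical. The result only approximates what jaxlib sees, because dlopen also consults the RPATH/RUNPATH of the calling binary.

```python
import ctypes
import os


def try_load(libname):
    """Try to dlopen `libname` through ctypes.

    Returns None if the library loads, otherwise the loader's error message.
    """
    try:
        ctypes.CDLL(libname)
    except OSError as exc:  # raised when the shared object cannot be found/loaded
        return str(exc)
    return None


if __name__ == "__main__":
    # Show the search path this process was started with.
    print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "<unset>"))
    # Library name taken from the warning above; adjust for other CUDA versions.
    err = try_load("libcusolver.so.11")
    print("libcusolver.so.11:", "OK" if err is None else "FAILED: " + err)
```

If this reports FAILED while the file exists under the conda environment's lib directory, that points to a search-path problem rather than a broken CUDA installation.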
I am fairly sure that QR decomposition is not the only function that fails in this setup. The same code works with cudatoolkit=10.2 (also installed from conda), and it also works on another machine with CUDA 11.1 installed outside of conda.
I am encountering the same issue with cudatoolkit 11.3.1, cudatoolkit-dev 11.3.1 and cudnn 8.2.1 from conda/conda-forge. In my case it fails in the eigh function, so, as @PgLoLo already mentioned, it is very likely not the only broken function.
I tried both the libraries from conda/conda-forge and the libraries from nvidia via conda.
I had the same issue and solved it by setting LD_LIBRARY_PATH:
export LD_LIBRARY_PATH=/path/to/miniconda3/envs/{your_env_name}/lib
python -c "import jax; jax.numpy.linalg.qr(jax.numpy.ones([3,3]))"
# works fine
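The export above replaces LD_LIBRARY_PATH entirely, which drops any entries that were already set, such as the /usr/local/nvidia paths shown in the original warning. Here is a minimal sketch of prepending the environment's lib directory instead. It assumes the script is run with the conda environment's own interpreter, so sys.prefix is the environment root. The function name is hypothetical. The value has to be exported before Python starts, because on glibc-based Linux the dynamic loader reads LD_LIBRARY_PATH at process startup. Changing os.environ inside an already-running session generally does not affect later library loads.

```python
import os
import sys


def prepend_to_path_var(value, directory, sep=":"):
    """Return a path-list string with `directory` first and other entries kept.

    `value` is the current contents of a variable such as LD_LIBRARY_PATH
    (may be None or empty); any existing copy of `directory` is removed.
    """
    entries = [p for p in (value or "").split(sep) if p and p != directory]
    return sep.join([directory] + entries)


if __name__ == "__main__":
    # Assumes the conda env's python is used, so sys.prefix is e.g. .../envs/test.
    env_lib = os.path.join(sys.prefix, "lib")
    print(prepend_to_path_var(os.environ.get("LD_LIBRARY_PATH"), env_lib))
```

Saved as a file (for example a hypothetical prepend_env_lib.py), it can be used as export LD_LIBRARY_PATH="$(python prepend_env_lib.py)". The plain shell equivalent is export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH.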
@k-khr that works for me, thanks!
@PgLoLo was this resolved?
One has to set LD_LIBRARY_PATH manually. Is there a better solution for this?
I personally set it in my conda environment via conda env config vars set LD_LIBRARY_PATH=... However, this has to be done for each new environment and is easy to forget when creating a new one.
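To make the per-environment step less error-prone, a small helper can print the exact conda env config vars set command for the active environment. This is a sketch under stated assumptions: the helper name is hypothetical, and it assumes CONDA_PREFIX is set by conda activate, falling back to the interpreter's prefix. Like the export above, it puts the environment's lib directory first. Optional extra directories can be appended to keep them on the path.

```python
import os
import posixpath
import shlex
import sys


def conda_env_var_command(prefix, extra_dirs=()):
    """Build a `conda env config vars set` command for LD_LIBRARY_PATH.

    `prefix` is the conda environment root; its lib/ directory goes first,
    followed by any non-empty `extra_dirs` (e.g. system driver directories).
    """
    dirs = [posixpath.join(prefix, "lib")] + [d for d in extra_dirs if d]
    value = ":".join(dirs)
    # Quote the VAR=value argument in case the path contains spaces.
    return "conda env config vars set " + shlex.quote("LD_LIBRARY_PATH=" + value)


if __name__ == "__main__":
    # CONDA_PREFIX is set by `conda activate`; fall back to the interpreter prefix.
    prefix = os.environ.get("CONDA_PREFIX", sys.prefix)
    print(conda_env_var_command(prefix))
```

Run the printed command inside the activated environment, then reactivate the environment (conda deactivate followed by conda activate) so the variable takes effect. The value is stored as a literal path, so it has to be regenerated for each new environment.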