
cuSolver is not initialised error

Open qixuanf opened this issue 4 years ago • 2 comments

I got RuntimeError: cuSolver has not been initialized while running the following code. Thanks in advance for any help!

import jax
import jaxlib
import jax.numpy as jnp

print(jax.__version__)  # 0.2.6
print(jaxlib.__version__)  # 0.1.57

jax.grad(jnp.linalg.det)(jnp.eye(2))

Other setup info:

  • CUDA Toolkit: 10.1
  • nvidia driver: 430.34

Full error stack trace with additional logs from tensorflow:

2020-12-24 10:00:28.586780: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
2020-12-24 10:00:28.656856: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
0.2.6
0.1.57
2020-12-24 10:00:29.051373: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624f9c183b0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.051404: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Interpreter, <undefined>
2020-12-24 10:00:29.054250: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2100000000 Hz
2020-12-24 10:00:29.068014: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624fa256fc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.068049: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-12-24 10:00:29.071225: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2020-12-24 10:00:29.199562: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624fa370e90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.199655: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Quadro P4000, Compute Capability 6.1
2020-12-24 10:00:29.200672: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc:119] XLA backend allocating 2953130803 bytes on device 0 for BFCAllocator.
2020-12-24 10:00:29.201539: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
2020-12-24 10:00:29.677721: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
Traceback (most recent call last):
  File "/volper/users/gaowang/fqx/gp_prune/gppr/bug_script.py", line 8, in <module>
    jax.grad(jnp.linalg.det)(jnp.eye(2))
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 250, in _det_jvp
    y, z = _cofactor_solve(x, g)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 207, in _cofactor_solve
    lu, pivots, permutation = lax_linalg.lu(a)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 143, in lu
    lu, pivots, permutation = lu_p.bind(x)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 745, in _lu_impl
    lu, pivot, perm = xla.apply_primitive(lu_p, operand)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 821, in _lu_cpu_gpu_translation_rule
    lu, pivot, info = getrf_impl(c, operand)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jaxlib/cusolver.py", line 156, in getrf
    lwork, opaque = cusolver_kernels.build_getrf_descriptor(
jax._src.traceback_util.FilteredStackTrace: RuntimeError: cuSolver has not been initialized

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/volper/users/gaowang/fqx/gp_prune/gppr/bug_script.py", line 8, in <module>
    jax.grad(jnp.linalg.det)(jnp.eye(2))
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 706, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 769, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 1796, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 113, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 100, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 488, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/custom_derivatives.py", line 212, in __call__
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/custom_derivatives.py", line 278, in bind
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)  # type: ignore
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 325, in process_custom_jvp_call
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 250, in _det_jvp
    y, z = _cofactor_solve(x, g)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 207, in _cofactor_solve
    lu, pivots, permutation = lax_linalg.lu(a)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 143, in lu
    lu, pivots, permutation = lu_p.bind(x)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py", line 270, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py", line 580, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 745, in _lu_impl
    lu, pivot, perm = xla.apply_primitive(lu_p, operand)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 235, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 278, in xla_primitive_callable
    built_c = primitive_computation(prim, AxisEnv(nreps, (), (), None), backend,
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 319, in primitive_computation
    ans = rule(c, *xla_args, **params)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 821, in _lu_cpu_gpu_translation_rule
    lu, pivot, info = getrf_impl(c, operand)
  File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jaxlib/cusolver.py", line 156, in getrf
    lwork, opaque = cusolver_kernels.build_getrf_descriptor(
RuntimeError: cuSolver has not been initialized

qixuanf, Dec 24 '20 09:12

Did you install the jaxlib with GPU support?

https://github.com/google/jax#installation
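
One way to check which devices an installed jaxlib exposes is to inspect jax.devices(). The following is a rough sketch, assuming each device object has a platform attribute; the helper names gpu_devices and jax_sees_gpu are hypothetical and not part of JAX. Note that the log in the original report already shows an XLA service initialized for platform CUDA on a Quadro P4000.

```python
def gpu_devices(devices):
    """Filter an iterable of device objects down to GPU-backed ones."""
    # Assumption: GPU devices report platform "gpu"; "cuda" and "rocm" are
    # accepted as well in case a backend reports the vendor name instead.
    return [d for d in devices
            if getattr(d, "platform", None) in ("gpu", "cuda", "rocm")]


def jax_sees_gpu():
    """Return True if the installed jax/jaxlib exposes at least one GPU."""
    import jax  # imported lazily so gpu_devices stays usable without jax

    return bool(gpu_devices(jax.devices()))
```

On a CPU-only jaxlib install, jax.devices() lists only CPU devices, so jax_sees_gpu() would return False. In that case, reinstalling jaxlib with GPU support as described in the linked instructions is the first thing to try.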

zhangqiaorjc, Jan 12 '21 21:01

I still have this issue, does anyone have the solution?

RorroArt, Feb 25 '21 09:02

@qixuanf was this resolved?

sudhakarsingh27, Aug 12 '22 19:08

Closing since there has been no response. (Feel free to reopen if the issue isn't resolved on your end.)

sudhakarsingh27, Sep 14 '22 18:09

I also met this issue. RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver has not been initialized. Have you already solved it?

Schortenger, Jun 09 '23 05:06

I also met this issue. RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver has not been initialized. Have you already solved it?

I met the same problem. Have you solved it? Thanks.

zsc2003, Sep 21 '23 15:09

This message almost certainly means you ran out of GPU memory.

Please read: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
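
The page linked above documents environment variables that control how JAX reserves GPU memory. Below is a minimal sketch of setting them from Python, assuming they are set before jax is first imported. The helper name configure_gpu_memory is hypothetical and not part of JAX. The log in the original report shows XLA allocating 2953130803 bytes up front for its BFCAllocator, so preallocation is one thing worth ruling out.

```python
import os
from typing import Dict, Optional


def configure_gpu_memory(preallocate: bool = False,
                         mem_fraction: Optional[float] = None) -> Dict[str, str]:
    """Set XLA client memory environment variables; call before importing jax."""
    # XLA_PYTHON_CLIENT_PREALLOCATE=false turns off up-front GPU preallocation.
    settings = {"XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false"}
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory to preallocate when preallocation is on.
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    os.environ.update(settings)
    return settings


# Hypothetical usage: disable preallocation so memory is allocated on demand.
configure_gpu_memory(preallocate=False)
# import jax  # import jax only after the variables are set
```

Whether this clears the error depends on what else is using the GPU. If another process already holds most of the memory, these settings alone may not be enough.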

hawkinsp, Sep 21 '23 15:09

This message almost certainly means you ran out of GPU memory.

Please read: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

Thank you. But even at peak usage, my GPU still seems to have over 2 GiB of memory free. Are there any other possible reasons?

zsc2003, Sep 22 '23 07:09