jax
cuSolver is not initialised error
I got `RuntimeError: cuSolver has not been initialized` while running the following code. Thanks in advance for any help!
import jax
import jaxlib
import jax.numpy as jnp
print(jax.__version__) # 0.2.6
print(jaxlib.__version__) # 0.1.57
jax.grad(jnp.linalg.det)(jnp.eye(2))
Other setup info:
- CUDA Toolkit: 10.1
- nvidia driver: 430.34
Full error stack trace with additional logs from tensorflow:
2020-12-24 10:00:28.586780: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
2020-12-24 10:00:28.656856: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
0.2.6
0.1.57
2020-12-24 10:00:29.051373: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624f9c183b0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.051404: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2020-12-24 10:00:29.054250: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2100000000 Hz
2020-12-24 10:00:29.068014: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624fa256fc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.068049: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2020-12-24 10:00:29.071225: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2020-12-24 10:00:29.199562: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5624fa370e90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2020-12-24 10:00:29.199655: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Quadro P4000, Compute Capability 6.1
2020-12-24 10:00:29.200672: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc:119] XLA backend allocating 2953130803 bytes on device 0 for BFCAllocator.
2020-12-24 10:00:29.201539: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
2020-12-24 10:00:29.677721: I external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
Traceback (most recent call last):
File "/volper/users/gaowang/fqx/gp_prune/gppr/bug_script.py", line 8, in <module>
jax.grad(jnp.linalg.det)(jnp.eye(2))
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 250, in _det_jvp
y, z = _cofactor_solve(x, g)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 207, in _cofactor_solve
lu, pivots, permutation = lax_linalg.lu(a)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 143, in lu
lu, pivots, permutation = lu_p.bind(x)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 745, in _lu_impl
lu, pivot, perm = xla.apply_primitive(lu_p, operand)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 821, in _lu_cpu_gpu_translation_rule
lu, pivot, info = getrf_impl(c, operand)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jaxlib/cusolver.py", line 156, in getrf
lwork, opaque = cusolver_kernels.build_getrf_descriptor(
jax._src.traceback_util.FilteredStackTrace: RuntimeError: cuSolver has not been initialized
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/volper/users/gaowang/fqx/gp_prune/gppr/bug_script.py", line 8, in <module>
jax.grad(jnp.linalg.det)(jnp.eye(2))
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 706, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 769, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/api.py", line 1796, in _vjp
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 113, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 100, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 488, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/custom_derivatives.py", line 212, in __call__
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/custom_derivatives.py", line 278, in bind
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/ad.py", line 325, in process_custom_jvp_call
outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 250, in _det_jvp
y, z = _cofactor_solve(x, g)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 207, in _cofactor_solve
lu, pivots, permutation = lax_linalg.lu(a)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 143, in lu
lu, pivots, permutation = lu_p.bind(x)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py", line 270, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/core.py", line 580, in process_primitive
return primitive.impl(*tracers, **params)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 745, in _lu_impl
lu, pivot, perm = xla.apply_primitive(lu_p, operand)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 235, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 278, in xla_primitive_callable
built_c = primitive_computation(prim, AxisEnv(nreps, (), (), None), backend,
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/interpreters/xla.py", line 319, in primitive_computation
ans = rule(c, *xla_args, **params)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jax/_src/lax/linalg.py", line 821, in _lu_cpu_gpu_translation_rule
lu, pivot, info = getrf_impl(c, operand)
File "/volper/users/gaowang/anaconda3/envs/gppr/lib/python3.8/site-packages/jaxlib/cusolver.py", line 156, in getrf
lwork, opaque = cusolver_kernels.build_getrf_descriptor(
RuntimeError: cuSolver has not been initialized
Did you install the jaxlib with GPU support?
https://github.com/google/jax#installation
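One way to answer this question is to ask JAX which platforms it can see. The snippet below is a minimal sketch, not an official diagnostic: `report_jax_platforms` is a hypothetical helper name, and it relies only on `jax.devices()` and the `platform` attribute of each device. A result without a GPU platform would suggest a CPU-only jaxlib. In the log above, XLA reports a CUDA platform with a Quadro P4000, so the GPU build appears to be present in that environment.

```python
import importlib.util


def report_jax_platforms():
    """Return the sorted platform names JAX can see, or None if JAX is absent.

    Hypothetical helper for illustration only.
    """
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not installed in this environment
    import jax

    # Each device exposes a platform string; a CPU-only jaxlib lists no GPU platform.
    return sorted({device.platform for device in jax.devices()})


platforms = report_jax_platforms()
print("Platforms visible to JAX:", platforms)
```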
I still have this issue. Does anyone have a solution?
@qixuanf was this resolved?
Closing since there has been no response. (Feel free to reopen if the issue isn't resolved on your end.)
I also met this issue.
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver has not been initialized. Have you already solved it?
I met the same problem. Have you solved it? Thanks.
This message almost certainly means you ran out of GPU memory.
Please read: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
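The linked page describes environment variables that control how JAX reserves GPU memory. By default JAX preallocates a large share of the device for its own pool, so a library such as cuSolver that allocates outside that pool may find too little left. The following is a minimal sketch of one possible mitigation, assuming the environment variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` documented on that page; `configure_gpu_memory` is a hypothetical helper name, and the 0.50 fraction is an arbitrary illustrative value, not a recommendation.

```python
import os


def configure_gpu_memory(preallocate=False, mem_fraction=None):
    """Set JAX GPU memory options; must run before JAX is first imported.

    Hypothetical helper for illustration only.
    """
    # Disable (or enable) the up-front memory pool that JAX reserves at startup.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Limit the share of GPU memory JAX may claim for its pool.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_gpu_memory(preallocate=False, mem_fraction=0.50)

# Import JAX only after the variables are set, then rerun the reproduction:
# import jax
# import jax.numpy as jnp
# jax.grad(jnp.linalg.det)(jnp.eye(2))
```

If the error disappears with these settings, that would be consistent with the out-of-memory explanation; if it persists, the cause may lie elsewhere.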
Thank you. But even at peak usage my GPU still seems to have over 2 GiB of memory free. Are there any other possible causes?