
cuSolver internal error with `jax.scipy.stats.multivariate_normal.pdf()`

Open jwknaup opened this issue 1 year ago • 0 comments

Description

jax.scipy.stats.multivariate_normal.pdf(0, 1, 1)

The call above works fine with scalar inputs. However, with ndarray arguments:

import jax
import jax.numpy as jnp
jax.scipy.stats.multivariate_normal.pdf(jnp.zeros(1), jnp.ones(1), jnp.eye(1))

I receive the following error:

jax.scipy.stats.multivariate_normal.pdf(jnp.zeros(1), jnp.ones(1), jnp.eye(1))
E0626 18:17:20.343822   62859 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: INTERNAL: cuSolver internal error
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/scipy/stats/multivariate_normal.py", line 101, in pdf
    return lax.exp(logpdf(x, mean, cov))
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/scipy/stats/multivariate_normal.py", line 66, in logpdf
    L = lax.linalg.cholesky(cov)
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/lax/linalg.py", line 91, in cholesky
    return jnp.tril(cholesky_p.bind(x))
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/nvidia/jax/test_env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: cuSolver internal error
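
The traceback points at the Cholesky factorization inside `logpdf`, so one way to narrow this down is to call the factorization directly, without the stats wrapper. The following is only a diagnostic sketch: `try_cholesky` is a hypothetical helper name, not part of JAX, and on a working backend it simply returns the factor.

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg


def try_cholesky(cov):
    """Run the Cholesky factorization alone and report whether it succeeded.

    Hypothetical helper for diagnosis only; it is not a JAX API.
    """
    try:
        factor = lax_linalg.cholesky(cov)
        # Force execution so that a runtime failure surfaces inside the try block.
        factor.block_until_ready()
        return True, factor
    except Exception as err:  # e.g. XlaRuntimeError on a broken solver setup
        return False, err


ok, out = try_cholesky(jnp.eye(1))
print("backend:", jax.default_backend())
print("cholesky succeeded:", ok)
print(out)
```

If this fails with the same `cuSolver internal error`, the problem is in the GPU solver path rather than in `multivariate_normal` itself.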

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='orin-nano-jp6', release='5.15.136-tegra', version='#1 SMP PREEMPT Mon May 6 09:56:39 PDT 2024', machine='aarch64')

I am using a Jetson Orin Nano. I installed jax using the prebuilt aarch64 wheels with prepackaged cuda_12 and cudnn. Perhaps that is the issue, although it is odd that most other operations seem to work okay.
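
To check whether the failure is specific to the GPU build, the same call can be pinned to the CPU backend. This is a sketch of a possible check and temporary workaround, not a fix: `pdf_on_cpu` is a hypothetical helper name, and it assumes the CPU backend is available alongside CUDA.

```python
import jax
import numpy as np
from jax.scipy.stats import multivariate_normal


def pdf_on_cpu(x, mean, cov):
    """Evaluate multivariate_normal.pdf with all inputs committed to the CPU.

    Hypothetical helper for diagnosis; it is not a JAX API.
    """
    cpu = jax.devices("cpu")[0]
    # device_put commits the arrays to the CPU, so the computation runs there.
    x, mean, cov = (jax.device_put(np.asarray(a), cpu) for a in (x, mean, cov))
    with jax.default_device(cpu):
        return multivariate_normal.pdf(x, mean, cov)


result = pdf_on_cpu(np.zeros(1), np.ones(1), np.eye(1))
print(result)
```

If the CPU call returns a value while the default (CUDA) call raises, that would suggest the cuSolver setup in the prebuilt wheels rather than the inputs.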

jwknaup · Jun 26 '24 22:06