BUG: jax.scipy.stats.multivariate_normal.logpdf returns nan for high dimensionality matrices
Description
The following code example works perfectly for ndim = 10, but JAX returns nan for ndim = 100. It starts failing at ndim = 15. I am working in double precision by setting the environment variable JAX_ENABLE_X64=True.
import jax
import jax.numpy as jnp
import jax.random as random
def logprob_analytic(x_samp, mu, covar):
    nspec = x_samp.shape[0]
    diff = x_samp - mu
    x = jnp.linalg.solve(covar, diff)
    _, lndetC = jnp.linalg.slogdet(covar)
    lnL = -(jnp.dot(diff, x) + lndetC + nspec * jnp.log(2.0 * jnp.pi)) / 2.0
    return lnL
ndim = 10
seed = 42
key = random.PRNGKey(seed)
key, subkey = random.split(key)
# Create a random lower triangular matrix with positive diagonal and random entries below the diagonal
L = random.normal(subkey, (ndim, ndim))
L = L.at[jnp.diag_indices_from(L)].set(jnp.sqrt(0.1 * jnp.exp(L[jnp.diag_indices_from(L)])))
L = L.at[jnp.triu_indices_from(L,k=1)].set(0.0)
# Build the covariance from its Cholesky factor L
covar = jnp.matmul(L, L.T)
# Mean is zero, generate one sample
mu = jnp.zeros(ndim)
nsamp = 1
key, subkey = random.split(key)
x_samp = random.multivariate_normal(subkey, mu, covar, shape=(nsamp,), method='svd').squeeze()
# Evaluate the logprob in two ways
lnP_ana = logprob_analytic(x_samp, mu, covar)
lnP_jax = jax.scipy.stats.multivariate_normal.logpdf(x_samp, mu, covar)
assert jnp.isclose(lnP_ana, lnP_jax)
What jax/jaxlib version are you using?
Jax version 0.3.13
Which accelerator(s) are you using?
- [X] CPU
- [ ] GPU
- [ ] TPU
Additional System Info
Python 3.9.7, macOS on M1 (ARM)
I tracked down the issue to our logpdf using a Cholesky decomposition, which for some reason doesn't succeed in this case. Your code uses a more general (and more expensive) linear solver. I'm not a numerical stability expert, so I'm not sure what the best way to proceed is. @hawkinsp might be able to suggest some improvements here.
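One way to check whether the Cholesky step is what fails for a given covariance is to look for non-finite entries in the factor, since jnp.linalg.cholesky returns NaNs rather than raising when the decomposition fails. The sketch below is only a diagnostic; the helper name cholesky_is_finite and the two small test matrices are made up for illustration and are not part of JAX or of the report above.

```python
import jax
import jax.numpy as jnp

# Assumption: double precision, matching the JAX_ENABLE_X64=True setting in the report.
jax.config.update("jax_enable_x64", True)


def cholesky_is_finite(covar):
    """Return True if jnp.linalg.cholesky(covar) contains no NaN or inf entries.

    Hypothetical helper: JAX signals a failed decomposition by filling the
    result with NaNs, so a finiteness check is enough to detect failure.
    """
    chol = jnp.linalg.cholesky(covar)
    return bool(jnp.all(jnp.isfinite(chol)))


# Illustrative inputs (not taken from the report).
positive_definite = jnp.eye(3)
not_positive_definite = jnp.array([[1.0, 2.0], [2.0, 1.0]])

print(cholesky_is_finite(positive_definite))
print(cholesky_is_finite(not_positive_definite))
```

Calling cholesky_is_finite(covar) on the covar from the reproduction script for increasing ndim would show at which size the factorization starts to break down on a given platform.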
To be fair, SciPy's own scipy.stats.multivariate_normal.logpdf also fails in this case. It does, however, fail with a singular matrix warning; if you set allow_singular=True, it continues but gives bad results. I think the issue here is accumulation of roundoff error when performing a Cholesky decomposition of a large matrix. It is clear from slogdet and from how covar was constructed that it is not singular. A better implementation would switch to the more robust linear solver for covariance matrices that may be numerically singular.
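A minimal sketch of that fallback idea follows: try the Cholesky path first and use the solve/slogdet path only when the result is not finite. This is not how jax.scipy.stats.multivariate_normal.logpdf is implemented; the function names logpdf_cholesky, logpdf_solve, and logpdf_with_fallback are hypothetical, and the sketch handles a single sample vector only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Assumption: double precision, matching the JAX_ENABLE_X64=True setting in the report.
jax.config.update("jax_enable_x64", True)


def logpdf_cholesky(x, mu, covar):
    """Gaussian log-density via a Cholesky factor; yields NaN if the factorization fails."""
    n = x.shape[-1]
    chol = jnp.linalg.cholesky(covar)
    # Solve chol @ y = (x - mu), so that y @ y is the Mahalanobis term.
    y = solve_triangular(chol, x - mu, lower=True)
    half_logdet = jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (n * jnp.log(2.0 * jnp.pi) + jnp.dot(y, y)) - half_logdet


def logpdf_solve(x, mu, covar):
    """Gaussian log-density via a general linear solve and slogdet (as in the report)."""
    n = x.shape[-1]
    diff = x - mu
    _, logdet = jnp.linalg.slogdet(covar)
    maha = jnp.dot(diff, jnp.linalg.solve(covar, diff))
    return -0.5 * (maha + logdet + n * jnp.log(2.0 * jnp.pi))


def logpdf_with_fallback(x, mu, covar):
    """Use the Cholesky result when it is finite, otherwise the solve-based result."""
    fast = logpdf_cholesky(x, mu, covar)
    robust = logpdf_solve(x, mu, covar)
    return jnp.where(jnp.isfinite(fast), fast, robust)
```

Because jnp.where evaluates both branches, this version always pays for the general solve as well; it trades speed for simplicity and for staying compatible with jax.jit. It also does not check the sign returned by slogdet, so it assumes the caller passes a matrix that is positive definite in exact arithmetic.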