NaNs encountered when sampling from tfp.substrates.jax.distributions.WishartTriL
Hi,
I'm trying to sample from a Wishart distribution using tfp.substrates.jax and am running into NaNs, and I'm not sure why. Here's code that reproduces the error (with jax_debug_nans enabled):
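```python
import jax
import jax.numpy as jnp
from jax import random

# Use 64-bit floats and raise an error as soon as any operation produces a NaN.
# (Set these before importing TFP so no arrays are created with the old config.)
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

dtype = jnp.float64
nc = 13  # dimension of the Wishart matrices
k = 5    # draw 10**k samples

# Wishart with df = nc + 1 degrees of freedom and identity scale.
jax_wishart = tfd.WishartTriL(df=nc + 1, scale_tril=jnp.eye(nc, dtype=dtype))

key = random.PRNGKey(0)
key, subkey = random.split(key)
jax_wishart_samples = jax_wishart.sample(sample_shape=[10**k], seed=subkey)
```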
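For comparison, here is a minimal pure-JAX sketch of the standard Bartlett construction of a Wishart sample, W = L A Aᵀ Lᵀ, where L is the scale Cholesky factor, A is lower triangular with sqrt(chi²(df − i)) on the diagonal and N(0, 1) entries below it. It may help isolate whether the NaNs come from the chi-square (gamma) draws or from somewhere else. It is only an assumption that WishartTriL builds its samples in a similar way (I haven't checked its internals), and `bartlett_wishart_sample` is a hypothetical helper, not part of TFP.

```python
import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)


def bartlett_wishart_sample(key, df, scale_tril, num_samples):
    """Hypothetical helper: draw Wishart samples via the Bartlett decomposition.

    Returns an array of shape (num_samples, p, p) with W = L A A^T L^T, where
    L = scale_tril and A is lower triangular with A[i, i] = sqrt(chi2(df - i))
    (0-indexed rows) and A[i, j] ~ N(0, 1) for i > j.
    """
    p = scale_tril.shape[-1]
    dtype = scale_tril.dtype
    key_chi2, key_normal = random.split(key)

    # chi2(k) is 2 * Gamma(k / 2, 1); the degrees of freedom drop by one per row.
    chi2_df = df - jnp.arange(p, dtype=dtype)
    gamma = random.gamma(key_chi2, chi2_df / 2.0, shape=(num_samples, p), dtype=dtype)
    diag = jnp.sqrt(2.0 * gamma)

    # Strictly lower triangle from standard normals, sqrt(chi2) on the diagonal.
    normals = random.normal(key_normal, shape=(num_samples, p, p), dtype=dtype)
    a = jnp.tril(normals, k=-1) + diag[..., :, None] * jnp.eye(p, dtype=dtype)

    # (p, p) @ (n, p, p) broadcasts to (n, p, p).
    la = scale_tril @ a
    return la @ jnp.swapaxes(la, -1, -2)


# Usage with the same settings as the repro above (fewer samples to save memory).
nc = 13
samples = bartlett_wishart_sample(
    random.PRNGKey(0), nc + 1, jnp.eye(nc, dtype=jnp.float64), 10**4
)
print("any NaN:", bool(jnp.isnan(samples).any()))
```

Running this once with `jax_debug_nans` enabled and once without it, and comparing against `jnp.isnan(jax_wishart_samples).any()` with the flag disabled, could show whether NaNs actually reach the returned samples or are only raised by the debug flag for an intermediate value.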
Thanks in advance.