nan encountered with tfp.substrates.jax.distributions.WishartTriL

Status: Open · saraelshawa opened this issue 3 years ago · 0 comments

Hi,

I'm trying to sample from a Wishart distribution using tfp.substrates.jax, but I run into NaNs and I'm not sure why. Here's code that reproduces the error:

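import jax
import jax.numpy as jnp
from jax import random
# Import the JAX substrate of TFP directly.
from tensorflow_probability.substrates import jax as tfp

# Use float64 and raise an error as soon as a JAX operation produces a NaN.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

dtype = jnp.float64
nc = 13  # matrix dimension: each sample is nc x nc
k = 5    # draw 10**k samples

# Wishart with df = nc + 1 and identity scale, given by its Cholesky factor.
jax_wishart = tfp.distributions.WishartTriL(
    df=nc + 1, scale_tril=jnp.eye(nc, dtype=dtype))
key = random.PRNGKey(0)
key, subkey = random.split(key)
jax_wishart_samples = jax_wishart.sample(seed=subkey, sample_shape=[10**k])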
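As an independent cross-check (not a fix, and not assumed to be TFP's own implementation), the same distribution can be sampled with the standard Bartlett decomposition in plain NumPy, to see whether NaNs also appear outside the JAX/TFP path. For a scale matrix with lower Cholesky factor L, a draw is W = L A Aᵀ Lᵀ, where A is lower triangular with A[i, i] = sqrt(χ²(df − i)) (0-based i) and independent standard normals strictly below the diagonal; this requires df > nc − 1. The function name `bartlett_wishart_sample` is hypothetical, and the sketch assumes the scale is passed as its Cholesky factor, mirroring `scale_tril`.

```python
import numpy as np


def bartlett_wishart_sample(rng, df, scale_tril, num_samples):
    """Hypothetical helper: Wishart(df, scale_tril @ scale_tril.T) draws via Bartlett."""
    p = scale_tril.shape[-1]
    if df <= p - 1:
        raise ValueError("Bartlett decomposition needs df > p - 1")
    # Diagonal of A: sqrt of chi-squared draws with df - i degrees of freedom.
    chi2 = rng.chisquare(df - np.arange(p), size=(num_samples, p))
    # Strict lower triangle of A: independent standard normals (tril acts on last two axes).
    a = np.tril(rng.standard_normal((num_samples, p, p)), k=-1)
    idx = np.arange(p)
    a[:, idx, idx] = np.sqrt(chi2)
    # W = (L A)(L A)^T, batched over the leading sample axis.
    la = scale_tril @ a
    return la @ np.swapaxes(la, -1, -2)


rng = np.random.default_rng(0)
nc = 13
# 10**4 draws keeps memory modest; the issue itself draws 10**5.
samples = bartlett_wishart_sample(rng, df=nc + 1, scale_tril=np.eye(nc), num_samples=10**4)
print("samples containing NaNs:", int(np.isnan(samples).any(axis=(-2, -1)).sum()))
# With identity scale, E[W] = df * I.
print("max |mean - df * I|:", np.abs(samples.mean(axis=0) - (nc + 1) * np.eye(nc)).max())
```

If these NumPy draws are NaN-free while the `WishartTriL` draw is not, that would point at the JAX sampling path rather than at the parameters `df=nc+1` and an identity scale.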

Thanks in advance.

saraelshawa · Feb 13 '23 15:02