
Dirichlet distribution sampling issue when jit_compile=True

Open • LorenzoRimella opened this issue 2 years ago • 1 comment

Some seeds appear to produce NaNs when sampling from a Dirichlet distribution inside a function compiled with `jit_compile=True`. Any idea why? Below is an example script, tested on Google Colab.

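```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Concentration vector; the third entry is 0, which makes the Dirichlet degenerate.
dirichlet_lambda = tf.convert_to_tensor([2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype=tf.float32)
# Stateless seed that produces NaNs in the compiled function.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype=tf.int32)
# A nearby seed (second element differs by 2), for comparison.
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype=tf.int32)

@tf.function(jit_compile=True)
def jitwhat(concentration, seed):
    # Draw a (13, 10) batch of samples; each sample has 10 components,
    # so the result has shape (13, 10, 10).
    theta_j_k = tfp.distributions.Dirichlet(concentration=concentration).sample((13, 10), seed=seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
# Print the indices of any NaN entries in the samples.
print(np.where(np.isnan(foo)))
```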

Note that this Dirichlet distribution is "degenerate" because one of its concentration parameters is zero. Usually the sampler simply returns a zero in the corresponding position, but with this particular seed it returns NaN.
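For context, a Dirichlet draw is commonly constructed by normalizing independent Gamma(alpha_i, 1) draws, which is why a zero concentration would normally show up as a zero in that position. The NumPy sketch below illustrates that textbook construction; it is not TFP's implementation, and `dirichlet_via_gamma` is a hypothetical helper. It treats zero-concentration components explicitly as point masses at zero and only draws the positive components, which might also serve as a workaround idea, though nothing in this thread confirms that.

```python
import numpy as np

def dirichlet_via_gamma(concentration, sample_shape, rng):
    """Hypothetical helper: Dirichlet samples via normalized Gamma draws.

    Components with concentration exactly zero are treated as point masses
    at zero: they are never passed to the Gamma sampler and stay exactly 0.
    """
    alpha = np.asarray(concentration, dtype=np.float64)
    if np.any(alpha < 0):
        raise ValueError("concentration must be non-negative")
    positive = alpha > 0
    if not np.any(positive):
        raise ValueError("at least one concentration must be positive")

    sample_shape = tuple(sample_shape)
    out = np.zeros(sample_shape + alpha.shape)
    # Independent Gamma(alpha_i, 1) draws for the positive components only.
    g = rng.gamma(shape=alpha[positive], size=sample_shape + (int(positive.sum()),))
    # Normalize along the last axis so each sample sums to 1.
    out[..., positive] = g / g.sum(axis=-1, keepdims=True)
    return out

# Usage with the concentration vector from the report.
rng = np.random.default_rng(0)
alpha = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
theta = dirichlet_via_gamma(alpha, (13, 10), rng)
print(theta.shape)                # (13, 10, 10)
print(np.isnan(theta).any())      # False
print(np.all(theta[..., 2] == 0)) # True
```

Masking the zero entries up front keeps them out of the Gamma-then-normalize arithmetic entirely, so they cannot contribute a non-finite value.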

LorenzoRimella, Feb 20 '24

Verified as a potential bug. Colab here.

chrism0dwk, Feb 27 '24