Sampling from TruncatedNormal can yield NaN

Open georgematheos opened this issue 1 year ago • 2 comments

Example:

import jax
from tensorflow_probability.substrates import jax as tfp

# Arguments are (loc, scale, low, high).
tfp.distributions.TruncatedNormal(
    0.5382424, 0.05, 0.80921564, 0.86921564
).sample(seed=jax.random.PRNGKey(2))

returns NaN.

JAX version: 0.4.33. TFP version: 0.23.0.

georgematheos · Sep 30 '24 14:09

@derifatives pointed out that tfp.distributions.TruncatedNormal.sample wraps jax.random.truncated_normal. (We may be misunderstanding when this function is called.)

However, note that jax.random.truncated_normal can be used directly to sample from the truncated normal distribution above without any obvious issue:

import jax

mean, std, minval, maxval = 0.5382424, 0.05, 0.80921564, 0.86921564
# Standardize the bounds, sample from the standard truncated normal, then rescale.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample  # = 0.80921566

georgematheos · Sep 30 '24 15:09

However, this does not always work:

import jax

mean, std, minval, maxval = 0.09121108, 0.1, 0.62490195, 0.6849019
# Same procedure as above: standardize the bounds, sample, then rescale.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample  # = NaN
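
For context, both configurations in this thread truncate to a narrow interval more than five standard deviations above the mean. The sketch below uses only the Python standard library (no JAX or TFP) to compute the standardized bounds and the probability mass between them in float64. The helper names `standardized_bounds` and `truncation_mass` are hypothetical. The resulting mass is below float32 machine epsilon (about 1.2e-7), so single-precision CDF-based sampling may lose resolution in this region. Whether that is what actually produces the NaN is an assumption and has not been confirmed in this thread.

```python
import math


def standardized_bounds(mean, std, minval, maxval):
    """Return the truncation bounds in units of standard deviations from the mean."""
    return (minval - mean) / std, (maxval - mean) / std


def truncation_mass(lo, hi):
    """Standard-normal probability of [lo, hi], computed in float64.

    Uses erfc rather than erf so that upper-tail values are not lost to
    cancellation against 1.0.
    """
    return 0.5 * (math.erfc(lo / math.sqrt(2.0)) - math.erfc(hi / math.sqrt(2.0)))


# The two parameter sets reported in this thread: (mean, std, minval, maxval).
configs = {
    "TFP example": (0.5382424, 0.05, 0.80921564, 0.86921564),
    "jax.random example": (0.09121108, 0.1, 0.62490195, 0.6849019),
}

for name, params in configs.items():
    lo, hi = standardized_bounds(*params)
    mass = truncation_mass(lo, hi)
    print(f"{name}: bounds = [{lo:.4f}, {hi:.4f}] sigma, mass = {mass:.3e}")
```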

(I am finding these strange-seeming parameter configurations by running a fairly complex probabilistic inference program that samples millions of times from TruncatedNormals, then filtering the results to find where NaNs were generated.)
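
The filtering step described above could look roughly like the following. This is a minimal sketch, not the actual inference program: `find_nan_cases` is a hypothetical helper, and the inputs are just the two parameter sets and the resulting values reported in this thread for the jax.random.truncated_normal path.

```python
import math


def find_nan_cases(params, samples):
    """Return the parameter tuples whose corresponding sample is NaN."""
    return [p for p, s in zip(params, samples) if math.isnan(s)]


# (mean, std, minval, maxval) for each draw, paired with the value it produced.
params = [
    (0.5382424, 0.05, 0.80921564, 0.86921564),
    (0.09121108, 0.1, 0.62490195, 0.6849019),
]
samples = [0.80921566, float("nan")]

nan_cases = find_nan_cases(params, samples)
print(nan_cases)
```

Keeping the parameters alongside each sample is what makes the failing configurations reproducible in isolation.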

georgematheos · Sep 30 '24 15:09