RelaxedOneHotCategorical gives incorrect output for logits below 1 when the logits are float32
Hi, I was trying out RelaxedOneHotCategorical in tensorflow_probability version 0.19.0. The following code gives me an incorrect distribution.
from tensorflow_probability.substrates import jax as tfp
temperature = 0.5
p = [0.1, 0.5, 0.4]
dist = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=p)
from jax import random
dist.sample(seed=random.PRNGKey(0))
The expected behavior is that the 2nd class is the most likely to be the largest component in samples. However, the sample I got is dominated by the 1st class instead:
Array([0.9872344, 0.00204739, 0.01071812], dtype=float32)
This behavior disappears when the logits are greater than 1, or if p is cast to float16.
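A single draw may not be enough to tell a systematic bias from an unlucky sample, so one way to check is to compare how often each class is the largest component against softmax(logits), which for [0.1, 0.5, 0.4] is roughly (0.26, 0.39, 0.35). The sketch below is not the TFP implementation. It uses only the Python standard library and assumes RelaxedOneHotCategorical follows the usual Gumbel-softmax construction, softmax((logits + Gumbel noise) / temperature). The helper names softmax, sample_relaxed_one_hot and argmax_frequencies are hypothetical and exist only for this illustration.

```python
import math
import random


def softmax(xs):
    # Numerically stable softmax over a plain Python list.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_relaxed_one_hot(logits, temperature, rng):
    # Assumed construction (Gumbel-softmax), not TFP's actual code:
    # add Gumbel(0, 1) noise to each logit, divide by the temperature,
    # then apply softmax.
    noisy = []
    for logit in logits:
        u = max(rng.random(), 1e-12)  # avoid log(0)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / temperature)
    return softmax(noisy)


def argmax_frequencies(logits, temperature, num_samples, seed=0):
    # Fraction of samples in which each class is the largest component.
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(num_samples):
        sample = sample_relaxed_one_hot(logits, temperature, rng)
        counts[sample.index(max(sample))] += 1
    return [c / num_samples for c in counts]


logits = [0.1, 0.5, 0.4]
freqs = argmax_frequencies(logits, temperature=0.5, num_samples=20000)
probs = softmax(logits)
print("argmax frequencies:", freqs)
print("softmax(logits):   ", probs)
```

Under that assumption the two printed vectors should be close, and the 1st class should still come out largest in roughly a quarter of the draws, so a sample like the one above can occur even when sampling is correct. Running the same frequency check against dist.sample with many different seeds would show whether the library output really deviates from this.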
Is this expected?