
RelaxedOneHotCategorical has incorrect output for logits under 1 when logits are of type float32

Open umyta opened this issue 2 years ago • 0 comments

Hi, I was trying out the RelaxedOneHotCategorical distribution in tensorflow_probability version 0.19.0. The following code gives me what looks like an incorrect distribution.

from jax import random
from tensorflow_probability.substrates import jax as tfp

temperature = 0.5
p = [0.1, 0.5, 0.4]  # passed as logits, not probabilities
dist = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=p)
# Draw a single sample with a fixed PRNG key
dist.sample(seed=random.PRNGKey(0))

The expected behavior is that the 2nd class is the most likely to be the largest component in samples. However, the sample I got favors the 1st class instead.

Array([0.9872344, 0.00204739, 0.01071812], dtype=float32)
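
The array above is a single draw, so one way to probe the behavior further is to count how often each class ends up as the largest component over many draws. The following is a minimal sketch in plain Python, assuming the standard Gumbel-softmax construction (softmax of `(logits + Gumbel noise) / temperature`). It is not TFP's actual implementation, and the helper names `softmax`, `sample_relaxed_one_hot`, and `argmax_frequencies` are hypothetical.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_relaxed_one_hot(logits, temperature, rng):
    """One draw via the Gumbel-softmax construction (assumed, not TFP's code)."""
    gumbels = []
    for _ in logits:
        # Clamp the uniform draw away from 0 and 1 so both logs are finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbels.append(-math.log(-math.log(u)))
    return softmax([(l + g) / temperature for l, g in zip(logits, gumbels)])


def argmax_frequencies(logits, temperature, num_samples, seed=0):
    """Fraction of draws in which each class is the largest component."""
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(num_samples):
        sample = sample_relaxed_one_hot(logits, temperature, rng)
        counts[sample.index(max(sample))] += 1
    return [c / num_samples for c in counts]


logits = [0.1, 0.5, 0.4]
print("softmax(logits):   ", softmax(logits))
print("argmax frequencies:", argmax_frequencies(logits, 0.5, 50_000))
```

Under this construction the argmax frequencies should track `softmax(logits)`, which is roughly 0.26, 0.39, 0.35 for these logits. The 2nd class is then the most frequent winner, but the 1st class still comes out on top in about a quarter of draws. The same count can be run against `dist.sample` with a batch of samples to compare.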

This behavior disappears when the logits are > 1, or if we cast p to be float16.

Is this expected?

umyta, Jan 02 '23 20:01