numpyro
Add Entropy/Mode methods to discrete distributions
Solves #1696. Adds the following:

- `entropy` method for: Bernoulli/Categorical/DiscreteUniform/Geometric
- `mode` property for: Bernoulli/Categorical/Binomial/Poisson/Geometric
- `name` property for all distributions and an explanation for the `NotImplementedError`.
Current status: passes all local tests on CPU and GPU (tested on JAX 0.4.23 and CUDA 12.2). Added dependencies: none.
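For context, a rough usage sketch of the additions described above (illustrative only; the exact API is whatever lands in the merged PR):

```python
# Illustrative usage sketch of the new methods/properties listed above;
# exact names and behavior follow the merged PR, not this snippet.
import jax.numpy as jnp
import numpyro.distributions as dist

b = dist.Bernoulli(probs=0.3)
print(b.entropy())  # entropy in nats
print(b.mode)       # 0, since P(X=0) > P(X=1) when p = 0.3

g = dist.Geometric(probs=jnp.array([0.1, 0.5, 0.9]))
print(g.entropy())  # elementwise entropies
print(g.mode)       # zeros: a Geometric distribution's mode is always 0
```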
@fehiepsi Hello! Take a look whenever you're ready. Thank you!
Hi @stergiosba, it seems that using `-q * logq / p - logp` (with `q = 1 - p`) is pretty stable.
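For reference, that expression is the closed-form entropy of a Geometric distribution with success probability $p$:

```math
H(\mathrm{Geometric}(p)) \;=\; \frac{-(1-p)\log(1-p) - p\log p}{p} \;=\; -\frac{(1-p)\log(1-p)}{p} - \log p
```

For example: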
```python
import jax
import jax.numpy as jnp

logits = jnp.array([-100., -80., -60., -40., -20., 0., 20., 40., 60.])
logq = -jax.nn.softplus(logits)   # log(1 - p), computed stably from logits
logp = -jax.nn.softplus(-logits)  # log(p)
p = jax.scipy.special.expit(logits)
# probs = clamp_probs(self.probs)
p_clip = jnp.clip(p, jnp.finfo(p.dtype).tiny)  # guard against division by zero
print(-(1 - p) * logq / p_clip - logp)
```
gives
```
Array([1.0000000e+02, 8.1000000e+01, 6.1000000e+01, 4.1000000e+01,
       2.1000000e+01, 1.3862944e+00, 2.0611537e-09, 4.2483541e-18,
       8.7565109e-27], dtype=float32)
```
while the actual values are
```python
from decimal import Decimal

for i in logits:
    l = Decimal(str(i))
    p = 1 / (1 + (-l).exp())
    print("entropy", -((1 - p) * (1 - p).ln() + p * p.ln()) / p)
```
```
entropy 1E+2
entropy 80.00000000000000000000000002
entropy 61.00496650303780216962340234
entropy 41.00000000000197983295132689
entropy 21.00000000103057681049624664
entropy 1.386294361119890618834464243
entropy 4.328422607333389151485388142E-8
entropy 1.741825244552915897262840401E-16
entropy 5.487531564015271267712575141E-25
```
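Wrapping the stable expression in method form, a rough sketch of how a logits-parameterized Geometric entropy could look (illustrative only; `geometric_entropy` is a hypothetical name, not the PR's actual code):

```python
# Sketch only; the actual PR implementation may differ.
import jax
import jax.numpy as jnp

def geometric_entropy(logits):
    """Entropy of Geometric(logits) via -(1 - p) * log(1 - p) / p - log(p),
    with both log terms computed stably from the logits."""
    logq = -jax.nn.softplus(logits)   # log(1 - p)
    logp = -jax.nn.softplus(-logits)  # log(p)
    p = jax.scipy.special.expit(logits)
    p_safe = jnp.clip(p, jnp.finfo(p.dtype).tiny)  # guard the division
    return -(1 - p) * logq / p_safe - logp
```

Calling `geometric_entropy(logits)` on the logits array above reproduces the values from the first snippet, since it is the same computation.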
Hi @stergiosba, we will release numpyro 0.14 in a few days. Do you want to incorporate this feature into it?