numpyro

Add Entropy/Mode methods to discrete distributions

Open • stergiosba opened this issue 6 months ago • 3 comments

Solves #1696. Adds the following:

  • entropy method for: Bernoulli/Categorical/DiscreteUniform/Geometric
  • mode property for: Bernoulli/Categorical/Binomial/Poisson/Geometric
  • name property for all distributions, plus an explanatory message for NotImplementedError
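As a rough illustration of the closed forms these entropy methods compute, here is a standalone JAX sketch. The helper names (`bernoulli_entropy`, `categorical_entropy`, `discrete_uniform_entropy`, `geometric_entropy`) are hypothetical and are not the numpyro API. The Bernoulli and Geometric versions take logits, and the Geometric one assumes the distribution counts failures before the first success (the entropy is the same if it counts trials).

```python
import jax
import jax.numpy as jnp


def bernoulli_entropy(logits):
    # H = -p log p - q log q, with log p = -softplus(-logits) and log q = -softplus(logits)
    p = jnp.exp(-jax.nn.softplus(-logits))
    q = jnp.exp(-jax.nn.softplus(logits))
    return p * jax.nn.softplus(-logits) + q * jax.nn.softplus(logits)


def categorical_entropy(logits):
    # H = -sum_k p_k log p_k over the last axis, using log_softmax for stability
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)


def discrete_uniform_entropy(low, high):
    # uniform over the integers low, ..., high: H = log(high - low + 1)
    return jnp.log(high - low + 1.0)


def geometric_entropy(logits):
    # H = -(q log q) / p - log p; p is floored at the smallest normal float
    # so that very negative logits do not divide by zero
    p = jnp.exp(-jax.nn.softplus(-logits))
    q = jnp.exp(-jax.nn.softplus(logits))
    p_safe = jnp.maximum(p, jnp.finfo(p.dtype).tiny)
    return q * jax.nn.softplus(logits) / p_safe + jax.nn.softplus(-logits)
```

Computing p and q as `exp(-softplus(...))` rather than as `1 - sigmoid(...)` keeps the small probability accurate when the other one rounds to 1 in float32.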

Current status: Passes all local tests on CPU and GPU (tested on JAX 0.4.23 and CUDA 12.2).

Added dependencies: None
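A similar hedged sketch for the mode values, again as standalone JAX with hypothetical helper names rather than numpyro's implementation. The Geometric case assumes support {0, 1, 2, ...}. Some cases have two modes (Bernoulli at p = 0.5, Binomial when (n + 1)p is an integer, Poisson with an integer rate, Categorical with tied probabilities), and this sketch simply returns one of them.

```python
import jax.numpy as jnp


def bernoulli_mode(probs):
    # 1 if p > 0.5 else 0; at p == 0.5 both outcomes are modes and this returns 0
    return (probs > 0.5).astype(jnp.int32)


def categorical_mode(probs):
    # index of the most probable category (the first index wins on ties)
    return jnp.argmax(probs, axis=-1)


def binomial_mode(total_count, probs):
    # floor((n + 1) p), clipped to n so that p == 1 gives n rather than n + 1
    return jnp.minimum(jnp.floor((total_count + 1) * probs), total_count).astype(jnp.int32)


def poisson_mode(rate):
    # floor(rate); for an integer rate, rate - 1 is also a mode
    return jnp.floor(rate).astype(jnp.int32)


def geometric_mode(probs):
    # assuming support {0, 1, 2, ...}: P(k) = (1 - p)^k p decreases in k, so the mode is 0
    return jnp.zeros_like(probs, dtype=jnp.int32)
```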

stergiosba • Dec 23 '23 01:12

@fehiepsi Hello! Take a look whenever you are ready. Thank you!

stergiosba • Dec 29 '23 08:12

Hi @stergiosba, it seems that computing the Geometric entropy as -q * log(q) / p - log(p), with q = 1 - p, is pretty stable.
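For context, here is where that expression comes from, assuming the Geometric distribution counts failures k = 0, 1, 2, ... before the first success, with success probability p = σ(ℓ) for a logit ℓ (the entropy is the same if it counts trials instead). The last line gives the softplus forms used for `logp` and `logq`.

```latex
\begin{aligned}
P(k) &= q^{k} p, \qquad k = 0, 1, 2, \dots, \qquad q = 1 - p, \\
H &= -\sum_{k \ge 0} q^{k} p \,\bigl(k \log q + \log p\bigr)
   = -\log p - \mathbb{E}[k] \log q
   = -\log p - \frac{q}{p} \log q
   = -\frac{q \log q + p \log p}{p}, \\
\log p &= -\operatorname{softplus}(-\ell), \qquad
\log q = -\operatorname{softplus}(\ell), \qquad p = \sigma(\ell).
\end{aligned}
```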

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-100., -80., -60., -40., -20., 0., 20., 40., 60.])
logq = -jax.nn.softplus(logits)   # log(1 - p), computed stably
logp = -jax.nn.softplus(-logits)  # log(p), computed stably
p = jax.scipy.special.expit(logits)
# inside the distribution this would be: probs = clamp_probs(self.probs)
p_clip = jnp.maximum(p, jnp.finfo(p.dtype).tiny)  # avoid dividing by zero
print(-(1 - p) * logq / p_clip - logp)
```

gives

```
Array([1.0000000e+02, 8.1000000e+01, 6.1000000e+01, 4.1000000e+01,
       2.1000000e+01, 1.3862944e+00, 2.0611537e-09, 4.2483541e-18,
       8.7565109e-27], dtype=float32)
```

while the reference values computed with Python's decimal module (default 28-digit precision) are

```python
from decimal import Decimal

# `logits` is the array defined in the JAX snippet above
for i in logits:
    l = Decimal(str(i))
    p = 1 / (1 + (-l).exp())  # p = sigmoid(logit)
    # Geometric entropy: -((1 - p) ln(1 - p) + p ln p) / p
    print("entropy", -((1 - p) * (1 - p).ln() + p * p.ln()) / p)
```

```
entropy 1E+2
entropy 80.00000000000000000000000002
entropy 61.00496650303780216962340234
entropy 41.00000000000197983295132689
entropy 21.00000000103057681049624664
entropy 1.386294361119890618834464243
entropy 4.328422607333389151485388142E-8
entropy 1.741825244552915897262840401E-16
entropy 5.487531564015271267712575141E-25
```
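One caveat worth checking: `decimal` defaults to 28 significant digits, and computing `1 - p` at that precision cancels when |logit| is large. At logit -80, for instance, p is about 1.8e-35, so `1 - p` rounds to exactly 1, the q ln q term drops out, and the reference reports 80, whereas the exact entropy is about 81 (because -q ln q / p tends to 1 as p tends to 0). A sketch of a higher-precision reference, assuming the same formula but raising the precision and computing q directly instead of as `1 - p`; the function name `geometric_entropy_reference` and the choice of 60 digits are illustrative assumptions.

```python
from decimal import Decimal, localcontext


def geometric_entropy_reference(logit: float, prec: int = 60) -> Decimal:
    """Geometric entropy -(q ln q + p ln p) / p evaluated at `prec` significant digits."""
    with localcontext() as ctx:
        ctx.prec = prec  # well above the 28-digit default
        l = Decimal(str(logit))
        p = 1 / (1 + (-l).exp())  # success probability, sigmoid(logit)
        q = 1 / (1 + l.exp())     # failure probability, computed directly rather than as 1 - p
        return -(q * q.ln() + p * p.ln()) / p


for logit in [-100.0, -80.0, -60.0, -40.0, -20.0, 0.0, 20.0, 40.0, 60.0]:
    print(logit, geometric_entropy_reference(logit))
```

Against such a reference, the float32 results at logits -80 and -60 (81.0 and 61.0) should land close to the exact values. At logits 20, 40 and 60, `1 - p` rounds to 0 in float32, so the snippet's output there reduces to the -log p term alone.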

fehiepsi • Jan 01 '24 16:01

Hi @stergiosba, we will release numpyro 0.14 in a few days. Do you want to incorporate this feature into it?

fehiepsi • Feb 16 '24 19:02