
Distributions Entropy Method

Open stergiosba opened this issue 1 year ago • 15 comments

Hello guys, I come from the TensorFlow Distributions world and was looking for a lightweight alternative, and I was pleasantly surprised to see that Pyro is available for JAX via your amazing work.

I have implemented the PPO algorithm for some of my DRL problems, and inside the loss function the entropy of a Categorical distribution is needed. I saw that the CategoricalLogits class does not have an entropy method, unlike its counterparts in TFP and Distrax (from DeepMind). Is there a different, and possibly more streamlined, way to get it in numpyro without an external function of the following form:

import jax
import jax.numpy as jnp
import numpyro

def entropy(distr: numpyro.distributions.discrete.CategoricalLogits):
    logits = distr.logits
    return -jnp.sum(jax.nn.softmax(logits) * jax.nn.log_softmax(logits), axis=-1)  # reduce over categories

Is this a design choice? I have implemented an entropy method in the local numpyro I am using for my projects, but possibly others want this little feature added too.

Anyways let me know what you think.

Cheers!

stergiosba avatar Dec 12 '23 23:12 stergiosba

Yeah, it would be great if we had an entropy method, so you could do d.entropy() where d is a logits-parameterized categorical distribution.
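
For instance, a minimal sketch of the proposed usage (assuming the entropy method exists on the distribution):

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Categorical(logits=jnp.array([0.5, 1.0, -2.0]))
h = d.entropy()  # entropy of the categorical, computed from its logits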

fehiepsi avatar Dec 13 '23 01:12 fehiepsi

Changed the topic title, since most (if not all) distributions do not have an entropy method.

stergiosba avatar Dec 13 '23 20:12 stergiosba

@stergiosba Hi. Are you working on this issue? If not, I would like to take it.

yayami3 avatar Dec 18 '23 13:12 yayami3

I am actively working on it, yes. Let's collaborate if you want, @yayami3.

stergiosba avatar Dec 18 '23 20:12 stergiosba

@stergiosba Thanks for the offer. Which distributions are you targeting? I wrote a draft of the foundational classes and tests here

yayami3 avatar Dec 18 '23 23:12 yayami3

I am working on the discrete ones now. I added entropy as a method rather than a property, so it matches other Python libraries like Distrax and TFP.

I have done Categorical and Bernoulli. I double-checked against Distrax and TFP and get the same results they do. Just one comment after looking at your test cases:

@pytest.mark.parametrize(
    "jax_dist, sp_dist, params",
    [
        T(dist.BernoulliProbs, 0.2),
        T(dist.BernoulliProbs, np.array([0.2, 0.7])),
        T(dist.BernoulliLogits, np.array([-1.0, 3.0])),
    ],
)

Make sure you cover edge cases like exploding logits.
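
For example, a standalone sketch of such a check (the test name, values, and tolerance are illustrative, and it assumes the entropy method being added here):

import jax.numpy as jnp
import numpyro.distributions as dist

def test_bernoulli_entropy_extreme_logits():
    # A numerically (near-)deterministic Bernoulli has entropy very close to zero,
    # and the result should stay finite rather than turning into nan or inf.
    d = dist.BernoulliLogits(jnp.array([-40.0, 40.0]))
    h = d.entropy()
    assert jnp.all(jnp.isfinite(h))
    assert jnp.allclose(h, 0.0, atol=1e-6)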

For the Bernoulli distribution you used xlogy, which automatically handles the aforementioned problem, so the following works:

def entropy(self):
    return -xlogy(self.probs, self.probs) - xlog1py(1 - self.probs, -self.probs)

But I wanted to make the check explicit, so I did this, for example:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs0 = _to_probs_bernoulli(-1.0 * self.logits)
    probs1 = self.probs
    log_probs0 = -jax.nn.softplus(self.logits)
    log_probs1 = -jax.nn.softplus(-1.0 * self.logits)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
    plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
    return -qlogq - plogp

I compared the performance of both solutions and it is the same. Also, for some reason xlogy cuts off at the 8th decimal point, but that is minor.

I don't know which style is better. Maybe @fehiepsi can give his take on this.

stergiosba avatar Dec 19 '23 00:12 stergiosba

I am also adding a mode property for the distributions.

stergiosba avatar Dec 19 '23 01:12 stergiosba

@stergiosba Thanks for your comment! I think it's a good idea, but other modules use xlogy as well, so an explicit check feels inconsistent with them. I felt it was more important to make people aware of the purpose and effects of xlogy. Let's just wait for @fehiepsi anyway.

yayami3 avatar Dec 19 '23 12:12 yayami3

I think you can clip y and use xlogy. I remember that grad needs to be computed correctly at the extreme points. I don't have a strong opinion on the style though.

fehiepsi avatar Dec 19 '23 12:12 fehiepsi

Great catch there @fehiepsi

There is an issue with the gradients when using xlogy. I set up a toy problem to test gradients, and the xlogy implementation failed at the extreme case p = 1 by returning nan.
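
A minimal standalone sketch of that kind of check, with the entropy written as a free function:

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

def bernoulli_entropy(p):
    # H(p) = -p*log(p) - (1-p)*log(1-p), written with xlogy/xlog1py
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)

# The forward pass is fine at the extremes: H(0) = H(1) = 0.
print(bernoulli_entropy(1.0))            # 0.0
# The gradient is not, because log(1 - p) blows up at p = 1.
print(jax.grad(bernoulli_entropy)(1.0))  # nan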

I was able to fix the nan by adding lax.stop_gradient for the case p = 1, as follows:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    return -xlogy(probs, probs) - xlog1py(1 - probs, -probs)

The same works when guarding the edge cases with jnp.where while still using xlogy:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs == 1.0, 0.0, xlog1py(1.0 - probs, -probs))  # q*log(q) is the risky term at p = 1
    plogp = jnp.where(probs == 0.0, 0.0, xlogy(probs, probs))           # p*log(p) is the risky term at p = 0
    return -qlogq - plogp

Just for the record, the first entropy calculation I provided was based on Distrax's code, and it had no problems with gradients out of the box.

But we can go forward with the xlogy function, as it works with the addition of stop_gradient.

stergiosba avatar Dec 19 '23 22:12 stergiosba

I think it is better to do probs_positive = clip(probs, a_min=tiny) and compute xlogy(probs, probs_positive), and similarly with a probs_less_than_one for the other term. We need the grad of the entropy.
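
A minimal sketch of that suggestion (the names probs_positive and probs_less_than_one just mirror the comment above; this is not the actual numpyro code):

import jax.numpy as jnp
from jax.scipy.special import xlogy

def bernoulli_entropy(probs):
    probs = jnp.asarray(probs)
    tiny = jnp.finfo(probs.dtype).tiny
    # Clip only the argument of the log away from zero, and keep the original
    # probs as the multiplier so the gradient still flows through it.
    probs_positive = jnp.clip(probs, tiny)
    probs_less_than_one = jnp.clip(1.0 - probs, tiny)
    return -xlogy(probs, probs_positive) - xlogy(1.0 - probs, probs_less_than_one)

Unlike stop_gradient, this keeps a finite gradient at the boundaries, at the cost of a small clipping error.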

fehiepsi avatar Dec 19 '23 22:12 fehiepsi

Yeah, I was blind; I see the issue now. This is the clipping version:

def entropy(self, eps=1e-9):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = jnp.clip(self.probs, eps, 1.0 - eps)
    return -xlogy(probs, probs) - xlog1py(1.0 - probs, -probs)

Clipping works for the gradients but inherently introduces errors. For example, we fail to pass the test case with a large negative logit.
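
For instance, with the eps-clipped version above (numbers approximate):

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

def clipped_entropy(logits, eps=1e-9):
    p = jnp.clip(jax.nn.sigmoid(logits), eps, 1.0 - eps)
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)

# Once p falls below eps the result saturates around -eps*log(eps) ~ 2e-8,
# no matter how small the true entropy actually is.
print(clipped_entropy(-30.0))   # ~2.2e-8, true entropy ~3e-12
print(clipped_entropy(-100.0))  # ~2.2e-8, true entropy ~4e-42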

The following, although not the most beautiful, works for everything so I vote to go with it.

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp

stergiosba avatar Dec 19 '23 23:12 stergiosba

Any ideas on what to return when the logits are very negative for a Geometric distribution?

As you can see from the code below, we need to divide by p, and when the logits are very negative, p = sigmoid(logit) = 0.

TFP and PyTorch return nan in this case, and Distrax does not have a Geometric distribution.

def entropy(self):
    """Calculates the entropy of the Geometric distribution with probability p.
    H(p,q)=[-qlog(q)-plog(p)]*1/p where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Geometric distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)

    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return (-qlogq - plogp) * 1.0 / p
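
To make the failure mode concrete, here is the same computation as a standalone sketch (the probability helpers are inlined instead of using _to_probs_bernoulli):

import jax
import jax.numpy as jnp

def geometric_entropy(logits):
    p = jax.nn.sigmoid(logits)
    q = jax.nn.sigmoid(-logits)
    logp = -jax.nn.softplus(-logits)
    logq = -jax.nn.softplus(logits)
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return (-qlogq - plogp) / p

print(geometric_entropy(-5.0))    # ~6.0, still finite
print(geometric_entropy(-200.0))  # p underflows to 0, so this is 0/0 = nan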

stergiosba avatar Dec 20 '23 16:12 stergiosba

You can divide implicitly (rather than directly), e.g. I think you can use (I have not checked yet):

  • for bernoulli: (1 - probs) * logits - log_sigmoid(logits)
  • for geometric: -(1 + jnp.exp(-logits)) * log_sigmoid(-logits) - log_sigmoid(logits)

Edit: ignore me, exp(-logits) can be very large

fehiepsi avatar Dec 20 '23 20:12 fehiepsi

OK, I will add some tests for the Probs versions of the distributions and submit a PR for the discrete distributions, and you can review it there. Thanks for the help!

stergiosba avatar Dec 20 '23 21:12 stergiosba

Just an update: we have entropy methods for continuous distributions in https://github.com/pyro-ppl/numpyro/pull/1787 and https://github.com/pyro-ppl/numpyro/pull/1800. The methods for discrete distributions are a work in progress in #1706 (it is currently closed, so any contribution is welcome!).

fehiepsi avatar Aug 01 '24 19:08 fehiepsi

I think #1787 also covers most of the discrete distributions.

tillahoffmann avatar Aug 02 '24 15:08 tillahoffmann

@tillahoffmann Do you want to address the rest? There is a subtle numerical issue in the Geometric entropy, IIRC. I can also take a stab at it and would appreciate your help reviewing.

fehiepsi avatar Aug 10 '24 22:08 fehiepsi

@fehiepsi, do you recall what the numerical issue in Geometric is, and whether it is in the probs or the logits parameterization?

Based on which of the test_entropy_samples tests are skipped, it looks like the following implementations are still missing. I don't know off the top of my head which of these have analytic entropies but can certainly have a look.

  • [ ] AsymmetricLaplace
  • [ ] AsymmetricLaplaceQuantile
  • [ ] BetaBinomial
  • [ ] BinomialLogits
  • [ ] BinomialProbs
  • [ ] CAR
  • [ ] DirichletMultinomial
  • [ ] EulerMaruyama
  • [ ] FoldedNormal
  • [ ] GammaPoisson
  • [ ] GaussianCopulaBeta
  • [ ] GaussianRandomWalk
  • [ ] Gompertz
  • [ ] Gumbel
  • [ ] HalfCauchy
  • [ ] HalfNormal
  • [ ] Kumaraswamy
  • [ ] LKJ
  • [ ] LKJCholesky
  • [ ] MatrixNormal
  • [ ] MultinomialLogits
  • [ ] MultinomialProbs
  • [ ] MultivariateStudentT
  • [ ] NegativeBinomial2
  • [ ] NegativeBinomialLogits
  • [ ] NegativeBinomialProbs
  • [ ] OrderedLogistic
  • [ ] Poisson
  • [ ] ProjectedNormal
  • [ ] RelaxedBernoulliLogits
  • [ ] SineBivariateVonMises
  • [ ] SineSkewedUniform
  • [ ] SineSkewedVonMises
  • [ ] SineSkewedVonMisesBatched
  • [ ] SoftLaplace
  • [ ] SparsePoisson
  • [ ] TwoSidedTruncatedDistribution
  • [ ] VonMises
  • [ ] WishartCholesky
  • [ ] ZeroInflatedPoisson
  • [ ] ZeroInflatedPoissonLogits
  • [ ] ZeroSumNormal
  • [ ] _Gaussian2DMixture
  • [ ] _GaussianMixture
  • [ ] _General2DMixture
  • [ ] _GeneralMixture
  • [ ] _ImproperWrapper
  • [ ] _SparseCAR
  • [ ] _TruncatedCauchy
  • [ ] _TruncatedNormal

tillahoffmann avatar Aug 12 '24 17:08 tillahoffmann

The formula is H(p, q) = [-q*log(q) - p*log(p)] / p where q = 1 - p. It is preferred to work with logits instead of probs. The issue is maintaining precision when p is near 0 and near 1. I came up with a solution in https://github.com/pyro-ppl/numpyro/pull/1706#issuecomment-1873382888 but I'm not sure if we can do better. By the way, I think we don't need entropy for all distributions, just the ones that users have requested.

fehiepsi avatar Aug 18 '24 11:08 fehiepsi

Makes sense. See #1852 for a fix. Shall we use the list of items above to track which distributions users have requested an implementation for?

tillahoffmann avatar Aug 19 '24 15:08 tillahoffmann