Distributions Entropy Method
Hello, I come from the TensorFlow Probability (TFP) distributions world, was looking for a lightweight alternative, and was pleasantly surprised to see that Pyro is available for JAX through your amazing work.
I have implemented the PPO algorithm for some of my DRL problems, and its loss function needs the entropy of a Categorical distribution. I noticed that the CategoricalLogits class does not have an entropy method, unlike its counterparts in TFP and Distrax (from DeepMind). Is there a different, and possibly more streamlined, way to get it in numpyro without an external function of the following form?
```python
import jax
import jax.numpy as jnp
import numpyro

def entropy(distr: numpyro.distributions.discrete.CategoricalLogits):
    logits = distr.logits
    # Sum over the category axis so batched distributions give per-element entropies.
    return -jnp.sum(jax.nn.softmax(logits) * jax.nn.log_softmax(logits), axis=-1)
```
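For context, this is roughly how I use such a helper inside a PPO-style loss. The loss terms and coefficient names below are only illustrative, not part of numpyro:

```python
import jax.numpy as jnp
import numpyro.distributions as dist

# Hypothetical PPO entropy bonus: logits would come from a policy network head.
logits = jnp.array([[0.1, -0.3, 2.0], [0.5, 0.5, 0.5]])
policy = dist.CategoricalLogits(logits)
entropy_bonus = jnp.mean(entropy(policy))  # uses the helper defined above
# total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus
```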
Is this a design choice? I have implemented an entropy method in the local numpyro copy I use for my projects, but possibly others would want this little feature added as well.
Anyway, let me know what you think.
Cheers!
Yeah, it would be great to have an entropy method, so you can do d.entropy() where d is a logits-parameterized categorical distribution.
Changed the topic title since all (or at least most) distributions do not have an entropy method.
@stergiosba Hi. Are you working on this issue? If not, I would like to take it.
Yes, I am actively working on it. Let's collaborate if you want, @yayami3.
@stergiosba Thanks for the offer. Which distribution are you targeting? I wrote a draft about the foundational classes and tests here
I am working on the discrete ones now. I added entropy as a method rather than a property so that it matches other Python libraries like Distrax and TFP.
I have done Categorical and Bernoulli, and I double-checked against Distrax and TFP to make sure we get the same results. Just one comment after looking at your test cases:
```python
@pytest.mark.parametrize(
    "jax_dist, sp_dist, params",
    [
        T(dist.BernoulliProbs, 0.2),
        T(dist.BernoulliProbs, np.array([0.2, 0.7])),
        T(dist.BernoulliLogits, np.array([-1.0, 3.0])),
    ],
)
```
Make sure you cover edge cases like exploding logits (very large positive or negative values).
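For example, a sketch of the kind of edge-case check I have in mind (it assumes the entropy method under discussion already exists on BernoulliLogits):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# Extreme logits: the probabilities saturate to exactly 0.0 / 1.0 in float32.
extreme_logits = jnp.array([-1e4, -30.0, 0.0, 30.0, 1e4])
d = dist.BernoulliLogits(extreme_logits)
ent = d.entropy()                  # the method under discussion
assert jnp.all(jnp.isfinite(ent))  # no nan/inf
assert jnp.all(ent >= 0.0)         # Bernoulli entropy is non-negative

# Gradients should also stay finite at the extremes.
grads = jax.vmap(jax.grad(lambda l: dist.BernoulliLogits(l).entropy()))(extreme_logits)
assert jnp.all(jnp.isfinite(grads))
```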
For the Bernoulli distribution you used xlogy, which automatically handles the aforementioned problem, so the following works:
```python
def entropy(self):
    return -xlogy(self.probs, self.probs) - xlog1py(1 - self.probs, -self.probs)
```
But I wanted to make the check explicit, so I did this instead, for example:
```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs0 = _to_probs_bernoulli(-1.0 * self.logits)
    probs1 = self.probs
    log_probs0 = -jax.nn.softplus(self.logits)
    log_probs1 = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
    plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
    return -qlogq - plogp
```
I compared the performance of both solutions and it is the same. Also, for some reason the xlogy version differs after about the 8th decimal place, but that is minor (likely just float32 precision).
I don't know which style is better. Maybe @fehiepsi can give his take on this.
I am also adding a mode property for the distributions.
@stergiosba Thanks for your comment! I think it's a good idea, but there are other modules using xlogy as well, and it feels like there is a lack of consistency. I felt that it was more important to make people aware of the purpose and effects of xlogy. Let's just wait for @fehiepsi anyway.
I think you can clip y and use xlogy. I remember that grad needs to be computed correctly at the extreme points. I don't have a strong opinion on the style though.
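To make the extreme-point concern concrete, here is a minimal sketch of what happens to the gradient of the plain xlogy/xlog1py form at p = 1:

```python
import jax
from jax.scipy.special import xlogy, xlog1py

def bernoulli_entropy(p):
    # Plain form: -p*log(p) - (1 - p)*log(1 - p)
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)

print(bernoulli_entropy(1.0))            # 0.0 -- the value itself is fine
print(jax.grad(bernoulli_entropy)(1.0))  # nan -- the xlog1py term differentiates to 0/0 at p = 1
```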
Great catch there @fehiepsi. There is an issue with the gradients when using xlogy. I set up a toy problem to test gradients, and the xlogy implementation failed at the extreme case p = 1 by returning nan. I was able to fix the nan by adding lax.stop_gradient for the case where p = 1:
```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    return -xlogy(probs, probs) - xlog1py(1 - probs, -probs)
```
The same works when combining the explicit jnp.where masking with xlogy:
```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    # Could make this into a function if we need it elsewhere.
    # Guard the endpoint where the corresponding log term is singular.
    qlogq = jnp.where(probs == 1.0, 0.0, xlog1py(1.0 - probs, -probs))
    plogp = jnp.where(probs == 0.0, 0.0, xlogy(probs, probs))
    return -qlogq - plogp
```
Just for the record, the first entropy calculation I provided was based on Distrax's code, and it had no problems with gradients "out of the box". But we can go forward with the xlogy function, as it works once stop_gradient is added.
I think it is better to do probs_positive = clip(probs, a_min=tiny) and compute xlogy(probs, probs_positive), and similarly with a probs_less_than_one for the other term. We need the grad of the entropy, so we can't stop_gradient it away.
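If I read the suggestion right, it amounts to something like the following sketch, where tiny is the smallest positive normal float and only the argument of the log is clipped, so the values at the endpoints are unchanged:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

def bernoulli_entropy(probs):
    finfo = jnp.finfo(jnp.result_type(float))
    # Clip only the argument of the log, not the prefactor, so the value at the
    # endpoints is unchanged (xlogy(0, tiny) == 0) but the gradient stays finite.
    probs_positive = jnp.clip(probs, finfo.tiny)
    probs_less_than_one = jnp.clip(probs, None, 1.0 - finfo.eps)
    return -xlogy(probs, probs_positive) - xlog1py(1.0 - probs, -probs_less_than_one)

print(jax.grad(bernoulli_entropy)(1.0))  # finite (though clipped), not nan
print(jax.grad(bernoulli_entropy)(0.0))  # finite (though clipped), not nan
```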
Yeah, I was blind; I see the issue now. This is the clipping version:
```python
def entropy(self, eps=1e-9):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = jnp.clip(self.probs, eps, 1.0 - eps)
    return -xlogy(probs, probs) - xlog1py(1.0 - probs, -probs)
```
Clipping works for the gradients but inherently introduces errors; for example, we fail to pass the test case with a large negative logit.
The following, although not the most beautiful, works for everything, so I vote we go with it:
```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp
```
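As a sanity check, here is a standalone re-expression of that version (a sketch using jax.nn.sigmoid in place of the class internals) showing that both values and gradients stay finite across extreme logits:

```python
import jax
import jax.numpy as jnp

def entropy_from_logits(logits):
    # Same structure as the method above: everything is computed in logit space.
    q, p = jax.nn.sigmoid(-logits), jax.nn.sigmoid(logits)
    logq, logp = -jax.nn.softplus(logits), -jax.nn.softplus(-logits)
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp

for logits in [-200.0, -1.0, 0.0, 30.0, 200.0]:
    val, grad = jax.value_and_grad(entropy_from_logits)(jnp.float32(logits))
    print(logits, val, grad)  # finite everywhere; entropy peaks at log(2) for logits = 0
```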
Any ideas about what to return when the logits are very negative in a Geometric distribution?
As you can see from the code below, we need to divide by p, and when the logits are very negative p = sigmoid(logits) underflows to 0.
TFP and PyTorch return nan in this case, and Distrax does not have a Geometric distribution.
```python
def entropy(self):
    """Calculates the entropy of the Geometric distribution with probability p.

    H(p, q) = [-q*log(q) - p*log(p)] / p, where q = 1 - p,
    with extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Geometric distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return (-qlogq - plogp) / p
```
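For concreteness, a small sketch of the failure mode (the values are only illustrative; in float32 a logit of -200 is enough for p to underflow):

```python
import jax
import jax.numpy as jnp

logits = jnp.float32(-200.0)
p = jax.nn.sigmoid(logits)         # underflows to 0.0 in float32
q = jax.nn.sigmoid(-logits)        # rounds to 1.0
logp = -jax.nn.softplus(-logits)   # still finite: -200.0
logq = -jax.nn.softplus(logits)    # ~0.0
qlogq = jnp.where(q == 0.0, 0.0, q * logq)
plogp = jnp.where(p == 0.0, 0.0, p * logp)
print((-qlogq - plogp) / p)        # nan: numerator and p both underflow to 0
```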
You can divide implicitly (rather than directly). E.g. I think you can use (I have not checked yet):
- for Bernoulli: (1 - probs) * logits - log_sigmoid(logits)
- for Geometric: -(1 + jnp.exp(-logits)) * log_sigmoid(-logits) - log_sigmoid(logits)

Edit: ignore me, exp(-logits) can be very large.
OK, I will add some tests for the Probs versions of the distributions and submit a PR for the discrete distributions, and you can review it there. Thanks for the help!
Just an update: we have entropy methods for continuous distributions in https://github.com/pyro-ppl/numpyro/pull/1787 and https://github.com/pyro-ppl/numpyro/pull/1800. The methods for discrete distributions are WIP in #1706 (it is currently closed, so any contribution is welcome!).
I think #1787 also covers most of the discrete distributions.
@tillahoffmann Do you want to address the rest? There is a subtle numerical issue in Geometric entropy IIRC. I can also take a stab at it and would like to have your help in reviewing.
@fehiepsi, do you recall what the numerical issue in Geometric is, and whether it's with the probs or the logits parameterization?
Based on which of the test_entropy_samples tests are skipped, it looks like the following implementations are still missing. I don't know off the top of my head which of these have analytic entropies but can certainly have a look.
- [ ] AsymmetricLaplace
- [ ] AsymmetricLaplaceQuantile
- [ ] BetaBinomial
- [ ] BinomialLogits
- [ ] BinomialProbs
- [ ] CAR
- [ ] DirichletMultinomial
- [ ] EulerMaruyama
- [ ] FoldedNormal
- [ ] GammaPoisson
- [ ] GaussianCopulaBeta
- [ ] GaussianRandomWalk
- [ ] Gompertz
- [ ] Gumbel
- [ ] HalfCauchy
- [ ] HalfNormal
- [ ] Kumaraswamy
- [ ] LKJ
- [ ] LKJCholesky
- [ ] MatrixNormal
- [ ] MultinomialLogits
- [ ] MultinomialProbs
- [ ] MultivariateStudentT
- [ ] NegativeBinomial2
- [ ] NegativeBinomialLogits
- [ ] NegativeBinomialProbs
- [ ] OrderedLogistic
- [ ] Poisson
- [ ] ProjectedNormal
- [ ] RelaxedBernoulliLogits
- [ ] SineBivariateVonMises
- [ ] SineSkewedUniform
- [ ] SineSkewedVonMises
- [ ] SineSkewedVonMisesBatched
- [ ] SoftLaplace
- [ ] SparsePoisson
- [ ] TwoSidedTruncatedDistribution
- [ ] VonMises
- [ ] WishartCholesky
- [ ] ZeroInflatedPoisson
- [ ] ZeroInflatedPoissonLogits
- [ ] ZeroSumNormal
- [ ] _Gaussian2DMixture
- [ ] _GaussianMixture
- [ ] _General2DMixture
- [ ] _GeneralMixture
- [ ] _ImproperWrapper
- [ ] _SparseCAR
- [ ] _TruncatedCauchy
- [ ] _TruncatedNormal
The formula is H(p, q) = [-q*log(q) - p*log(p)] / p, where q = 1 - p. It is preferred to work with logits instead of probs. The issue is maintaining precision when p is near 0 and when p is near 1. I came up with a solution in https://github.com/pyro-ppl/numpyro/pull/1706#issuecomment-1873382888, but I am not sure if we can do better. By the way, I think we don't need entropy for all distributions, just the ones that users request.
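To illustrate why logits are preferred (a small float32 example): once p rounds to 1, the probs route loses the -q*log(q) term entirely, while the logit route keeps it.

```python
import jax
import jax.numpy as jnp

logits = jnp.float32(20.0)
p = jax.nn.sigmoid(logits)           # rounds to exactly 1.0 in float32
q_naive = 1.0 - p                    # 0.0: the tail mass is lost
q_logit = jax.nn.sigmoid(-logits)    # ~2.06e-9: still representable
print(q_naive * jnp.log(q_naive))           # nan (0 * -inf)
print(q_logit * -jax.nn.softplus(logits))   # ~-4.1e-8: q*log(q) computed in logit space
```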
Makes sense. See #1852 for a fix. Shall we use the list of items above to track which distributions users have requested an implementation for?