pyro copied to clipboard
Maybe add epsilon to RelaxedOneHotCategorical to prevent underflow
I've noticed that `pyro.distributions.RelaxedOneHotCategorical` underflows dramatically when the temperature drops below about 0.3 with many categories. I've been using a slight modification to the `rsample` method of the `ExpRelaxedCategorical` class it is built on. I wanted to post this in case you'd like to consider this (maybe hacky) fix, which makes the distribution work with Pyro's support constraints.
Modified from the upstream implementation:
class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    Unlike the upstream implementation, :meth:`rsample` floors every component
    of the relaxed sample at ``torch.finfo(dtype).tiny`` (in probability space)
    and renormalizes before taking the log. At low temperatures this keeps
    ``sample.exp()`` free of exact zeros, whose log would otherwise be ``-inf``
    and turn downstream log densities into NaN.

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample: the log of a point in the simplex.

        The result has shape ``sample_shape + batch_shape + event_shape``. Its
        ``exp()`` sums to 1 over the last dimension, with each component
        floored at roughly ``torch.finfo(dtype).tiny``.
        """
        shape = self._extended_shape(sample_shape)
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        # Softmax in probability space; at low temperature many entries underflow to 0.
        probs = (scores - scores.logsumexp(dim=-1, keepdim=True)).exp()
        # Clamp rather than add a hard-coded epsilon so the floor follows the dtype.
        probs = probs.clamp(min=torch.finfo(probs.dtype).tiny)
        # Renormalize over the event (last) dim. A fixed dim=1 would sum over
        # batch/sample dims for batched draws and fail outright for unbatched ones.
        return (probs / probs.sum(dim=-1, keepdim=True)).log()

    def log_prob(self, value):
        """Log-density of ``value`` (a log-simplex point) under this distribution."""
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        log_scale = torch.full_like(
            self.temperature, float(K)
        ).lgamma() - self.temperature.log().mul(-(K - 1))
        score = logits - value.mul(self.temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
FYI: I also took a stab at fixing the straight-through categorical. It could still use some work, but it works for me in a case where the previous RelaxedCategoricalStraightThrough would not train as part of a GMM-VAE:
class RelaxedQuantizeCategorical(torch.autograd.Function):
    """Straight-through autograd op that draws a Gumbel-softmax sample in log space.

    Forward treats its input as logits and perturbs them with Gumbel noise. It
    then divides by the class-level ``temperature``, adds the class-level
    ``epsilon`` to every probability, renormalizes over the last dimension and
    returns the log. Backward passes the incoming gradient through unchanged.

    ``temperature`` and ``epsilon`` are class attributes shared by all callers.
    Set them with :meth:`set_temperature` / :meth:`set_epsilon` before ``apply``.
    """

    temperature = None  # must be set via set_temperature() before use
    epsilon = 1e-10  # added to each probability to prevent underflow

    @staticmethod
    def set_temperature(new_temperature):
        """Set the temperature used by subsequent forward passes."""
        RelaxedQuantizeCategorical.temperature = new_temperature

    @staticmethod
    def set_epsilon(new_epsilon):
        """Set the additive probability floor used by subsequent forward passes."""
        RelaxedQuantizeCategorical.epsilon = new_epsilon

    @staticmethod
    def forward(ctx, soft_value):
        if RelaxedQuantizeCategorical.temperature is None:
            # Fail with a clear message instead of float(None)'s TypeError.
            raise RuntimeError(
                "RelaxedQuantizeCategorical.temperature is unset; "
                "call RelaxedQuantizeCategorical.set_temperature() first"
            )
        temperature = float(RelaxedQuantizeCategorical.temperature)
        epsilon = RelaxedQuantizeCategorical.epsilon
        uniforms = clamp_probs(
            torch.rand(soft_value.shape, dtype=soft_value.dtype, device=soft_value.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (soft_value + gumbels) / temperature
        probs = (scores - scores.logsumexp(dim=-1, keepdim=True)).exp() + epsilon
        # Normalize over the event (last) dim; a fixed dim=1 is wrong for batched input.
        hard_value = (probs / probs.sum(dim=-1, keepdim=True)).log()
        # Keep the pre-sampling input so log_prob can recover it.
        hard_value._unquantize = soft_value
        return hard_value

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: identity gradient.
        return grad


class ExpRelaxedCategoricalStraightThrough(Distribution):
    """Log-space relaxed categorical whose samples come from RelaxedQuantizeCategorical.

    Args:
        temperature: relaxation temperature, forwarded to RelaxedQuantizeCategorical.
        probs: event probabilities.
        logits: unnormalized log probability for each event.
        validate_args: forwarded to ``Distribution``.
        epsilon: additive probability floor, forwarded to RelaxedQuantizeCategorical.
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(
        self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10
    ):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        # Previously accepted but ignored; now forwarded to the sampler.
        self.epsilon = epsilon
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Must name this class: ExpRelaxedCategorical is not a base of it, so using
        # it here raised NotImplementedError / TypeError.
        new = self._get_checked_instance(ExpRelaxedCategoricalStraightThrough, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new.epsilon = self.epsilon
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategoricalStraightThrough, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a log-simplex sample via the straight-through RelaxedQuantizeCategorical op."""
        # NOTE(review): the posted body was truncated (it returned an undefined
        # `outs`). This wires the sampler to RelaxedQuantizeCategorical, which was
        # otherwise unused; confirm this matches the author's intent.
        shape = self._extended_shape(sample_shape)
        RelaxedQuantizeCategorical.set_temperature(self.temperature)
        RelaxedQuantizeCategorical.set_epsilon(self.epsilon)
        return RelaxedQuantizeCategorical.apply(self.logits.expand(shape))

    def log_prob(self, value):
        """Sum of normalized logits, broadcast against ``value``."""
        value = getattr(value, "_unquantize", value)
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # NOTE(review): the score depends on `value` only through its broadcast
        # shape, not its contents; confirm this is the intended density.
        return (logits - logits.logsumexp(dim=-1, keepdim=True)).sum(-1)
class SafeAndRelaxedOneHotCategoricalStraightThrough(
    TransformedDistribution, TorchDistributionMixin
):
    """Relaxed one-hot categorical on the simplex with a straight-through sampler.

    Defined as ``exp`` of :class:`ExpRelaxedCategoricalStraightThrough`, so
    :meth:`rsample` goes through that base distribution's sampler.

    Args:
        temperature: relaxation temperature.
        probs: event probabilities (pass either ``probs`` or ``logits``).
        logits: unnormalized log probability for each event.
        validate_args: forwarded to the base and transformed distributions.
        epsilon: additive probability floor forwarded to the base distribution.
    """

    # NOTE(review): written as a replacement for pyro's built-in straight-through
    # class, which the author found did not call the straight-through rsample in
    # their setup. The root cause was not established.
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(
        self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10
    ):
        base_dist = ExpRelaxedCategoricalStraightThrough(
            temperature, probs, logits, validate_args=validate_args, epsilon=epsilon
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Check against this class; RelaxedOneHotCategorical has a different
        # __init__, so passing it here raised NotImplementedError.
        new = self._get_checked_instance(
            SafeAndRelaxedOneHotCategoricalStraightThrough, _instance
        )
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
Hi @mtvector, I think our general design principle with distributions is to make them hackable with decent defaults. In this case I'd lean towards letting users add their own epsilon in a custom distribution class. In my own projects I often have one or two custom distributions for each data science project. What do you think of a simple patched distribution, just for your project?
from pyro.distributions import ExpRelaxedCategorical
class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
epsilon = 1e-10
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs + self.epsilon # prevent underflow
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
Actually, I often find that (1) clamping is safer than adding, and (2) it's best to use `torch.finfo(dtype).tiny`
rather than a hard-coded epsilon. So you might customize:
class SafeExpRelaxedCategorical2(ExpRelaxedCategorical):
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs.clamp(min=torch.finfo(outs.dtype).tiny)
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
Hi @fritzo, I agree in principle: you're right about the hackability, and about using a proper epsilon or `torch.finfo(...).tiny` (I'm still working on my coding modularity :). I do think it's important to fix the default, though. I used Pyro for two years and thought RelaxedOneHotCategorical was totally unusable, because it fails in the following example:
import pyro
import torch
import pyro.distributions as dist
# Reproduction of the underflow with the stock RelaxedOneHotCategorical.
# NOTE(review): the bodies of `model`/`guide` and the definition of `logits` were
# lost from this copy of the report. The versions below are a minimal
# reconstruction that samples the reported site "cat_sample" at low temperature
# with many categories; confirm against the original report.
TEMPERATURE = torch.tensor(0.1)
NUM_CATEGORIES = 100
BATCH_SIZE = 8


def model(logits):
    """Sample site "cat_sample" from a relaxed one-hot categorical with fixed logits."""
    with pyro.plate("batch", logits.shape[0]):
        pyro.sample(
            "cat_sample", dist.RelaxedOneHotCategorical(TEMPERATURE, logits=logits)
        )


def guide(logits):
    """Variational guide over "cat_sample" with learnable logits ``q_logits``."""
    q_logits = pyro.param("q_logits", torch.zeros_like(logits))
    with pyro.plate("batch", logits.shape[0]):
        pyro.sample(
            "cat_sample", dist.RelaxedOneHotCategorical(TEMPERATURE, logits=q_logits)
        )


pyro.clear_param_store()
logits = torch.randn(BATCH_SIZE, NUM_CATEGORIES)
optim = pyro.optim.Adam({"lr": 0.1})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optim, loss=elbo)
for step in range(10):
    loss = svi.step(logits)
    print(f"step {step}: loss = {loss}")
Giving the error due to underflow:
.../pyro/lib/python3.11/site-packages/pyro/poutine/ UserWarning: Encountered NaN: log_prob_sum at site 'cat_sample'
You're right about the fix. For instance, your first version resolves the underflow more elegantly than what I proposed:
import pyro.distributions
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions import TransformedDistribution
class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
epsilon = 1e-10
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs + self.epsilon # prevent underflow
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
class SafeRelaxedOneHotCategorical(TransformedDistribution, TorchDistributionMixin):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples lie on the simplex and are reparametrizable.

    Samples are ``exp`` of samples from :class:`SafeExpRelaxedCategorical`,
    which floors each component before taking the log.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = SafeRelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                                  torch.tensor([[0.1, 0.2, 0.3, 0.4]]))
        >>> m.sample()
        tensor([[ 0.1294,  0.2324,  0.3859,  0.2523]])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = SafeExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Check against this class; RelaxedOneHotCategorical has a different
        # __init__, so passing it here raised NotImplementedError.
        new = self._get_checked_instance(SafeRelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
# Same reproduction as above, but using SafeRelaxedOneHotCategorical.
# NOTE(review): the bodies of `model`/`guide` and the definition of `logits` were
# lost from this copy of the report. The versions below are a minimal
# reconstruction that samples site "cat_sample" at low temperature with many
# categories; confirm against the original report.
TEMPERATURE = torch.tensor(0.1)
NUM_CATEGORIES = 100
BATCH_SIZE = 8


def model(logits):
    """Sample site "cat_sample" from the underflow-safe relaxed one-hot categorical."""
    with pyro.plate("batch", logits.shape[0]):
        pyro.sample(
            "cat_sample", SafeRelaxedOneHotCategorical(TEMPERATURE, logits=logits)
        )


def guide(logits):
    """Variational guide over "cat_sample" with learnable logits ``q_logits``."""
    q_logits = pyro.param("q_logits", torch.zeros_like(logits))
    with pyro.plate("batch", logits.shape[0]):
        pyro.sample(
            "cat_sample", SafeRelaxedOneHotCategorical(TEMPERATURE, logits=q_logits)
        )


pyro.clear_param_store()
logits = torch.randn(BATCH_SIZE, NUM_CATEGORIES)
optim = pyro.optim.Adam({"lr": 0.1})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optim, loss=elbo)
for step in range(10):
    loss = svi.step(logits)
    print(f"step {step}: loss = {loss}")
This runs without error, as does my SafeAndRelaxedOneHotCategoricalStraightThrough above.
So, yeah, it seems like the default for RelaxedOneHotCategorical should use one of these SafeExpRelaxedCategorical bases you've proposed here?