[FEATURE] Add zero-inflated distribution likelihood models
Is your feature request related to a current problem? Please describe.
Zero-inflated distributions usefully describe many physical and business processes. In practice, supporting zero-inflation would also allow a simple extension of the many distributions supported on $\mathbb{R}_{>0}$ to $\mathbb{R}_{\geq 0}$.
Describe proposed solution
Implement zero-inflated versions of the supported torch distributions and add corresponding likelihood models. At first glance, it does not appear that this would be too hard to do with a mixin, e.g.:
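```python
import torch
from torch.distributions.utils import broadcast_all


class ZeroInflationMixin(torch.distributions.Distribution):
    def __init__(self, p, *args, **kwargs):
        self.p, = broadcast_all(p)  # probability of a structural zero
        super().__init__(*args, **kwargs)

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            # torch.where needs a boolean mask, so cast the Bernoulli draws
            zeros = torch.bernoulli(self.p.expand(shape)).bool()
            samples = super().sample(sample_shape)
        # If zeros is True, we return 0, otherwise we return the sample from the base distribution
        return torch.where(zeros, torch.zeros_like(samples), samples)

    def log_prob(self, value):
        log_p, log_1m_p = torch.log(self.p), torch.log1p(-self.p)
        log_prob_base_dist = super().log_prob(value) + log_1m_p
        # For a discrete base (e.g. Poisson), a zero can also come from the base pmf,
        # so P(0) = p + (1 - p) * base_pmf(0); the two terms are added in log space.
        # A continuous base would use log_p alone here and must not be evaluated at 0.
        log_prob_zero_inf = torch.logaddexp(log_p, log_prob_base_dist)
        return torch.where(value == 0, log_prob_zero_inf, log_prob_base_dist)
```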
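And then, for example:
```python
import torch
from torch.distributions import constraints

class ZeroInflatedPoisson(ZeroInflationMixin, torch.distributions.Poisson):
    # keep Poisson's "rate" constraint and add one for p
    arg_constraints = {**torch.distributions.Poisson.arg_constraints, "p": constraints.unit_interval}
    support = constraints.nonnegative_integer
# ... and other distributions as desired
```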
These could then be implemented as likelihood models.
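As a quick sanity check for a discrete base, the sketch below computes the zero-inflated Poisson log-pmf directly with plain torch. The helper `zip_log_prob` is hypothetical (not part of torch or Darts). Because a Poisson can itself produce zeros, the mass at zero is $p + (1-p)e^{-\lambda}$ rather than $p$ alone, and summing the pmf over a long enough range of the support should give approximately 1.

```python
import torch
from torch.distributions import Poisson


def zip_log_prob(value: torch.Tensor, rate: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Log-pmf of a zero-inflated Poisson with structural-zero probability p (hypothetical helper)."""
    base_lp = Poisson(rate).log_prob(value)
    log_p = torch.log(p)
    log_1m_p = torch.log1p(-p)
    # A zero can come from either mixture component, so their probabilities add
    zero_lp = torch.logaddexp(log_p, log_1m_p + base_lp)
    # A non-zero value can only come from the base Poisson
    nonzero_lp = log_1m_p + base_lp
    return torch.where(value == 0, zero_lp, nonzero_lp)


rate = torch.tensor(2.5)
p = torch.tensor(0.3)
support = torch.arange(0, 60, dtype=torch.float32)  # the tail beyond 59 is negligible for rate 2.5
total = zip_log_prob(support, rate, p).exp().sum()
print(float(total))  # approximately 1.0
```

Returning only `log(p)` at zero would drop the $(1-p)e^{-\lambda}$ term, so the pmf would no longer sum to 1.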
Describe potential alternatives
Nothing comes to mind, though quantile loss can be used in the interim.
Thanks for this feature request @eschibli; it does indeed sound like a valuable addition to Darts.
I wonder whether we can get this to work without having to define a dedicated zero-inflated likelihood for each base likelihood.
Something like the following:
```python
# import path as in Darts releases of the time; it may differ in newer versions
from darts.utils.likelihood_models import TorchLikelihood


class ZeroInflatedLikelihood(TorchLikelihood):
    def __init__(self, likelihood: TorchLikelihood, p):
        """
        Zero-inflated version of a given likelihood.

        This class wraps a given likelihood and adds a zero-inflation component to it.
        It is useful for modeling time series with many zeros, such as counts or discrete events.

        Parameters
        ----------
        likelihood
            The base likelihood to be wrapped.
        p
            The probability of the zero-inflation component, i.e., the probability that the target
            is a structural zero rather than a draw from the base likelihood.
        """
        # add the required logic to get the zero-inflation to work
```
It would take any Darts torch likelihood (all except QuantileRegression?) and add the zero-inflation component to it.
I haven't fully thought it through, but I believe it should be possible.
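One way such a generic wrapper might compute its loss is sketched below with plain torch and no Darts internals. It assumes the wrapped likelihood yields a torch `Distribution` and that the network emits one extra unconstrained output `p_logit` for the zero-inflation probability (both assumptions, as is the hypothetical helper name `zero_inflated_log_prob`). Discrete bases, detected via torch's `support.is_discrete`, add the base mass at zero. Continuous bases explain exact zeros with the point mass alone, and the base density is evaluated at a placeholder (`base.mean`, assumed finite and inside the support) so it never sees an out-of-support 0.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Distribution, Gamma, Poisson


def zero_inflated_log_prob(base: Distribution, p_logit: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Log-likelihood of `value` under a zero-inflated version of `base` (hypothetical helper).

    The structural-zero probability is p = sigmoid(p_logit), so any real-valued
    network output can be used for it without clamping.
    """
    log_p = F.logsigmoid(p_logit)  # log(p), numerically stable
    log_1m_p = F.logsigmoid(-p_logit)  # log(1 - p)
    is_zero = value == 0

    if base.support.is_discrete:
        # Discrete base: a zero can come from the point mass or from the base pmf,
        # so P(0) = p + (1 - p) * P_base(0), added in log space.
        base_term = log_1m_p + base.log_prob(value)
        zero_lp = torch.logaddexp(log_p, base_term)
    else:
        # Continuous base: a single point carries no base probability mass, so exact
        # zeros are explained by the point mass alone. Zeros are swapped for a
        # placeholder (assumed in-support) before calling the base log_prob.
        safe_value = torch.where(is_zero, base.mean.detach(), value)
        base_term = log_1m_p + base.log_prob(safe_value)
        zero_lp = log_p

    return torch.where(is_zero, zero_lp, base_term)


# Usage: p = sigmoid(0) = 0.5 for both elements
p_logit = torch.zeros(2)
value = torch.tensor([0.0, 3.0])
print(zero_inflated_log_prob(Poisson(torch.tensor([2.0, 2.0])), p_logit, value))
print(zero_inflated_log_prob(Gamma(torch.tensor([2.0, 2.0]), torch.tensor([1.0, 1.0])), p_logit, value))
```

Parameterizing p through a logit keeps it inside (0, 1) for any network output and lets `logsigmoid` produce both `log(p)` and `log(1 - p)` without taking the log of a value that has been rounded to exactly 0 or 1.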
Would you be interested in contributing?
Sure, I'll see what I can do next week. Thanks, Dennis.