support for custom `support` in `TransformedDistribution`
Feature Summary
Currently, `MixtureGeneral` supports a custom `support` argument, as discussed in this issue comment. It would be helpful to extend similar functionality to `TransformedDistribution`.
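For reference, here is a minimal sketch of that existing `MixtureGeneral` escape hatch, assuming the `support` keyword discussed in the linked comment; the component choices are purely illustrative:

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions import constraints

mixing = dist.Categorical(probs=jnp.array([0.4, 0.6]))
# Components whose individual supports differ; the explicit keyword states
# the intended support of the mixture directly.
components = [dist.Uniform(0.0, 1.0), dist.LogNormal(0.0, 1.0)]
mixture = dist.MixtureGeneral(mixing, components, support=constraints.positive)
```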
Why is this needed?
In some use cases, I employ custom `Transform` objects that are mathematically well-defined and invertible. However, these transforms don't inherently enforce the desired output support constraints. As a result, samples from the transformed distribution can fall outside the intended codomain, which leads to inconsistencies in model behavior and inference.
A concrete example is the following transformation from a primary mass $m_1$ and mass ratio $q \in (0, 1]$ to component masses $(m_1, m_2)$, where $m_2 = m_1 q$ and $m_2 \le m_1$:
```python
# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0
import jax.numpy as jnp
from jax import Array
from numpyro.distributions import constraints
from numpyro.distributions.constraints import _SingletonConstraint, independent, positive
from numpyro.distributions.transforms import Transform

# `decreasing_vector` (a constraint for non-increasing vectors) and
# `mass_ratio` (which returns q = m2 / m1) are GWKokab helpers; their
# imports are omitted here.


# The codomain constraint is defined first so the transform below can
# reference it at class-definition time.
class _PositiveDecreasingVector(_SingletonConstraint):
    r"""Constrain values to be positive and decreasing, i.e. :math:`\forall i<j,
    x_i \geq x_j`.
    """

    event_dim = 1

    def __call__(self, x):
        # Valid iff the vector is (non-strictly) decreasing along the event
        # dimension and every entry is positive.
        return decreasing_vector.check(x) & independent(positive, 1).check(x)

    def feasible_like(self, prototype):
        return jnp.ones(prototype.shape, dtype=prototype.dtype)

    def tree_flatten(self):
        # No parameters: empty leaves and auxiliary data.
        return (), ((), dict())

    def __eq__(self, other):
        return isinstance(other, _PositiveDecreasingVector)


positive_decreasing_vector = _PositiveDecreasingVector()


class PrimaryMassAndMassRatioToComponentMassesTransform(Transform):
    r"""Transforms a primary mass and mass ratio to component masses.

    .. math::
        f: (m_1, q) \to (m_1, m_1 q)

    .. math::
        f^{-1}: (m_1, m_2) \to (m_1, m_2 / m_1)
    """

    domain = constraints.independent(
        constraints.interval(
            jnp.zeros((2,)),
            jnp.array([jnp.finfo(jnp.result_type(float)).max, 1.0]),
        ),
        1,
    )
    r""":math:`\mathcal{D}(f) = \mathbb{R}_+ \times [0, 1]`"""

    codomain = positive_decreasing_vector
    r""":math:`\mathcal{C}(f) = \{(m_1, m_2) \in \mathbb{R}^2_+ \mid m_1 \geq m_2 > 0\}`"""

    def __call__(self, x: Array):
        # Forward map: (m_1, q) -> (m_1, m_2) with m_2 = m_1 * q.
        m1, q = jnp.unstack(x, axis=-1)
        m2 = jnp.multiply(m1, q)
        m1m2 = jnp.stack((m1, m2), axis=-1)
        return m1m2

    def _inverse(self, y: Array):
        # Inverse map: (m_1, m_2) -> (m_1, q) with q = m_2 / m_1.
        m1, m2 = jnp.unstack(y, axis=-1)
        q = mass_ratio(m2=m2, m1=m1)
        m1q = jnp.stack((m1, q), axis=-1)
        return m1q

    def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None):
        r"""The Jacobian of :math:`f` is lower triangular with diagonal
        :math:`(1, m_1)`, hence

        .. math::
            \ln\left(|\mathrm{det}(J_f)|\right) = \ln(|m_1|)
        """
        m1 = x[..., 0]
        return jnp.log(jnp.abs(m1))

    def tree_flatten(self):
        # No parameters: empty leaves and auxiliary data.
        return (), ((), dict())

    def __eq__(self, other):
        if not isinstance(other, PrimaryMassAndMassRatioToComponentMassesTransform):
            return False
        return self.domain == other.domain
```
While this transformation is mathematically correct, `TransformedDistribution` does not currently enforce the full constraint, i.e., $m_{\mathrm{min}}\leq m_2 \leq m_1 \leq m_{\mathrm{max}}$, which leads to invalid samples unless post-processing or custom sampling is added. The condition $m_1, m_2\in[m_{\mathrm{min}}, m_{\mathrm{max}}]$ is not part of the transformation itself; allowing an explicit `support` argument would let this constraint be enforced directly, just as it is in `MixtureGeneral`.
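For concreteness, here is a rough sketch of what the requested API could look like. The `support=` keyword on `TransformedDistribution` does not exist today, and the base distribution below is only a placeholder:

```python
import jax.numpy as jnp
import numpyro.distributions as dist

m_min, m_max = 5.0, 100.0

# Placeholder joint prior over (m_1, q) with m_1 in [m_min, m_max] and q in [0, 1].
base = dist.Uniform(
    low=jnp.array([m_min, 0.0]), high=jnp.array([m_max, 1.0])
).to_event(1)

# Requested (hypothetical) keyword: state the support of the transformed
# distribution explicitly instead of relying on the transform's codomain.
component_masses = dist.TransformedDistribution(
    base,
    PrimaryMassAndMassRatioToComponentMassesTransform(),
    support=positive_decreasing_vector,  # or a tighter constraint with the mass bounds
)
```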
Hi @Qazalbash, for user code, it might be cleaner to define a new transform distribution with a custom support. WDYT?
The current implementation of `TransformedDistribution` transforms the `log_prob`, `sample`, `cdf`, and `icdf` methods, but it doesn't transform the support. We should either add support transformation or allow passing a custom support (just like we did in the `MixtureGeneral` distribution).
The example I shared above is a mathematically valid transform and the boundary constraints are part of the base distribution, but the support of the transformed distribution is just the codomain of the transform, which is only partially correct.
I meant to define:

```python
class A(TransformedDistribution):
    support = ...
```
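A slightly fleshed-out version of that suggestion, as I understand it, building on the code above (the class name and constructor are illustrative):

```python
import numpyro.distributions as dist

class ComponentMassesDistribution(dist.TransformedDistribution):
    # The class attribute shadows TransformedDistribution's `support`
    # property, so the intended constraint is reported and validated.
    support = positive_decreasing_vector

    def __init__(self, base_dist, *, validate_args=None):
        super().__init__(
            base_dist,
            PrimaryMassAndMassRatioToComponentMassesTransform(),
            validate_args=validate_args,
        )
```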
I understood; I was just explaining my thoughts!
Related: #1989
@dylanhmorris
IMO, something like this would solve both problems!
```python
def support(self, value: ArrayLike) -> ArrayLike:
    # Walk the transform chain backwards: invert each transform, check its
    # domain on the inverted value, and finally check the base
    # distribution's support. Returns a boolean mask over the batch shape.
    mask = jnp.asarray(True)
    y = value
    for transform in reversed(self.transforms):
        x = transform.inv(y)
        mask = jnp.logical_and(mask, transform.domain(x))
        y = x
    mask = jnp.logical_and(mask, self.base_dist.support(y))
    return mask
```
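If `support` became a callable check like that, it could be used, for example, to mask out-of-support draws before evaluating densities (illustrative; `d` stands for such a transformed distribution):

```python
import jax
import jax.numpy as jnp

samples = d.sample(jax.random.PRNGKey(0), (1000,))
valid = d.support(samples)  # boolean mask, one entry per draw
log_prob = jnp.where(valid, d.log_prob(samples), -jnp.inf)
```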