
support for custom `support` in `TransformedDistribution`

Status: Open · Qazalbash opened this issue 8 months ago · 6 comments

Feature Summary

Currently, MixtureGeneral supports a custom support argument, as discussed in this issue comment. It would be helpful to extend similar functionality to TransformedDistribution.

Why is this needed?

In some use cases, I employ custom Transform objects that are well-defined mathematically and invertible. However, these transforms don't inherently enforce the desired output support constraints. As a result, samples from the transformed distribution can fall outside the intended codomain, which leads to inconsistencies in model behavior and inference.

A concrete example is the following transformation from a primary mass $m_1$ and mass ratio $q \in (0, 1]$ to component masses $(m_1, m_2)$, where $m_2 = m_1 q$ and $m_2 \le m_1$:

# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0

# Imports added for context (not part of the original excerpt); `mass_ratio`
# and `decreasing_vector` are GWKokab's own helpers.
import jax.numpy as jnp
from jax import Array
from numpyro.distributions import constraints
from numpyro.distributions.constraints import _SingletonConstraint, independent, positive
from numpyro.distributions.transforms import Transform


class PrimaryMassAndMassRatioToComponentMassesTransform(Transform):
    r"""Transforms a primary mass and mass ratio to component masses.

    .. math::
        f: (m_1, q)\to (m_1, m_1q)

    .. math::
        f^{-1}: (m_1, m_2)\to (m_1, m_2/m_1)
    """

    domain = constraints.independent(
        constraints.interval(
            jnp.zeros((2,)), jnp.array([jnp.finfo(jnp.result_type(float)).max, 1.0])
        ),
        1,
    )
    r""":math:`\mathcal{D}(f) = \mathbb{R}_+\times[0, 1]`"""
    codomain = positive_decreasing_vector
    r""":math:`\mathcal{C}(f)=\{(m_1, m_2)\in\mathbb{R}^2_+\mid m_1\geq m_2>0\}`"""

    def __call__(self, x: Array):
        m1, q = jnp.unstack(x, axis=-1)
        m2 = jnp.multiply(m1, q)
        m1m2 = jnp.stack((m1, m2), axis=-1)
        return m1m2

    def _inverse(self, y: Array):
        m1, m2 = jnp.unstack(y, axis=-1)
        q = mass_ratio(m2=m2, m1=m1)
        m1q = jnp.stack((m1, q), axis=-1)
        return m1q

    def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None):
        r"""
        .. math::
            \ln\left(|\mathrm{det}(J_f)|\right) = \ln(|m_1|)
        """
        m1 = x[..., 0]
        return jnp.log(jnp.abs(m1))

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        if not isinstance(other, PrimaryMassAndMassRatioToComponentMassesTransform):
            return False
        return self.domain == other.domain


class _PositiveDecreasingVector(_SingletonConstraint):
    r"""Constrain values to be positive and decreasing, i.e. :math:`\forall i<j, x_i
    \geq x_j`.
    """

    event_dim = 1

    def __call__(self, x):
        return decreasing_vector.check(x) & independent(positive, 1).check(x)

    def feasible_like(self, prototype):
        return jnp.ones(prototype.shape, dtype=prototype.dtype)

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        return isinstance(other, _PositiveDecreasingVector)

positive_decreasing_vector = _PositiveDecreasingVector()

While this transformation is mathematically correct, TransformedDistribution does not currently enforce the correct constraint, i.e., $m_{\mathrm{min}}\leq m_2 \leq m_1 \leq m_{\mathrm{max}}$, which leads to invalid samples unless post-processing or custom sampling is added. The condition $m_1,m_2\in[m_{\mathrm{min}}, m_{\mathrm{max}}]$ is not a part of the transformation; therefore, allowing an explicit support argument would enable this constraint to be enforced directly, just as it is done in MixtureGeneral.

Qazalbash avatar Jul 13 '25 23:07 Qazalbash

Hi @Qazalbash, for user code, it might be cleaner to define a new transform distribution with a custom support. WDYT?

fehiepsi avatar Jul 17 '25 05:07 fehiepsi

The current implementation of TransformedDistribution transforms the log_prob, sample, cdf, and icdf methods, but it does not transform the support. Either we should add the functionality of transforming the support, or we should allow passing a custom support (just as we did in the MixtureGeneral distribution).

The example I shared above is a mathematically valid transform, and the boundary constraints are part of the base distribution; but when I query the support of the transformed distribution, it reports the codomain of the transform, which is only partially correct.

Qazalbash avatar Jul 17 '25 07:07 Qazalbash

I meant to define

class A(TransformedDistribution):
  support = ...

fehiepsi avatar Jul 17 '25 08:07 fehiepsi

I understood that; I was just explaining my thoughts!

Qazalbash avatar Jul 17 '25 19:07 Qazalbash

Related: #1989

dylanhmorris avatar Jul 22 '25 22:07 dylanhmorris

@dylanhmorris

IMO, something like this would solve both problems!

def support(self, value: ArrayLike) -> ArrayLike:
    # Start from a scalar True so it broadcasts against masks whose event
    # dimensions have been reduced by the constraint checks.
    mask = jnp.asarray(True)
    y = value
    for transform in reversed(self.transforms):
        x = transform.inv(y)
        # `domain` is a Constraint, which is callable directly.
        mask = jnp.logical_and(mask, transform.domain(x))
        y = x
    mask = jnp.logical_and(mask, self.base_dist.support(y))
    return mask

Qazalbash avatar Jul 23 '25 03:07 Qazalbash