numpyro
Support constraints.cat and CatTransform
Hello!
I have a custom multi-dimensional distribution whose support may be truncated along some dimensions. In terms of constraints, each dimension will be real, greater_than, less_than, or interval, so I naively implemented the support as, e.g.:
ivl = constraints.interval([0., -jnp.inf, 5.], [jnp.inf, 0., 10.])
Right now, this is not really supported by the numpyro.distributions.constraints.Interval class: feasible_like() uses the midpoint of the bounds, and the unconstrained transform scales by upper - lower, so both become inf or nan as soon as a bound is infinite. Would you be open to making these inf-safe? For now I have implemented a custom subclass, InfSafeInterval(constraints._Interval), to support this, but thought I would check in first. Thanks!
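To make the failure concrete: with the bounds above, the midpoint of the bounds is (inf, -inf, 7.5) and upper - lower is (inf, inf, 5), so neither gives a usable initial point or affine scale. Below is a minimal sketch of one inf-safe alternative, assuming plain JAX: each dimension uses a scaled sigmoid when both bounds are finite, a shifted exp when only one bound is finite, and the identity when neither is. The helpers inf_safe_feasible_like and inf_safe_forward are hypothetical, not numpyro API, and a real Transform subclass would also need the inverse and the log|det J|.

```python
import jax
import jax.numpy as jnp


def _finite_parts(lower, upper):
    lo_fin = jnp.isfinite(lower)
    hi_fin = jnp.isfinite(upper)
    # Swap infinite bounds for 0 so the unused jnp.where branches below
    # never compute inf - inf = nan.
    lo = jnp.where(lo_fin, lower, 0.0)
    hi = jnp.where(hi_fin, upper, 0.0)
    return lo_fin, hi_fin, lo, hi


def inf_safe_feasible_like(lower, upper):
    """Hypothetical helper: a finite point inside each dimension's support."""
    lo_fin, hi_fin, lo, hi = _finite_parts(lower, upper)
    return jnp.where(
        lo_fin & hi_fin,
        (lo + hi) / 2,  # midpoint of a finite interval
        jnp.where(lo_fin, lo + 1.0,  # a point above a finite lower bound
                  jnp.where(hi_fin, hi - 1.0, 0.0)),  # below an upper bound, else 0
    )


def inf_safe_forward(x, lower, upper):
    """Hypothetical helper: map unconstrained x elementwise into the support."""
    lo_fin, hi_fin, lo, hi = _finite_parts(lower, upper)
    both = lo + (hi - lo) * jax.nn.sigmoid(x)  # interval
    lower_only = lo + jnp.exp(x)  # greater_than
    upper_only = hi - jnp.exp(x)  # less_than
    return jnp.where(
        lo_fin & hi_fin,
        both,
        jnp.where(lo_fin, lower_only, jnp.where(hi_fin, upper_only, x)),  # x: real
    )


lower = jnp.array([0.0, -jnp.inf, 5.0])
upper = jnp.array([jnp.inf, 0.0, 10.0])
naive_mid = (lower + upper) / 2  # inf, -inf, 7.5: not a usable init point
init = inf_safe_feasible_like(lower, upper)  # 1.0, -1.0, 7.5
```

Selecting per dimension with jnp.where keeps everything vectorized and jit-friendly, at the cost of evaluating every branch for every dimension.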
Hi @adrn, for parameters with different domains, it is better to split them into separate constraints, e.g. ivl0, ivl1, and ivl2 in your case. We don't plan to support mixed-domain supports.
Thanks for the response! That makes sense. But hm, what I want to do may not be supported at the moment. In my custom distribution, some pairs of parameters are not independent, so I can't easily split them out. And some parameters may have real support, some interval, and some greater_than/less_than. This is probably a hack, but I ended up implementing a custom Constraint and Transform that handle the joint and independent parameters, and their supports, separately...
If you want to build a custom joint distribution, then cat constraints and CatTransform might be helpful: https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.CatTransform We can modify the title and make this a feature request, I guess?
We did something like this with the ConstraintCollection here: https://github.com/compmem/cognax/blob/main/cognax/joint/gaussian_copula.py - may be a helpful resource.