
Support constraints.cat and CatTransform

Status: Open. adrn opened this issue 1 year ago (4 comments).

Hello!

I have a custom multi-dimensional distribution whose support may be truncated along some dimensions. In terms of constraints, each dimension is one of real, greater_than, less_than, or interval. I naively implemented the support as, e.g.:

ivl = constraints.interval([0., -jnp.inf, 5.], [jnp.inf, 0., 10.])

Right now, this is not really supported by the interval constraint in numpyro.distributions.constraints. feasible_like() returns the midpoint (lower_bound + upper_bound) / 2, which is infinite when a bound is infinite, and the transform between unconstrained and constrained space scales by upper_bound - lower_bound, which is also infinite. Would you be open to making these inf-safe? For now I have implemented a custom subclass, InfSafeInterval(constraints._Interval), to support this, but I thought I would check in first. Thanks!

adrn, Sep 30 '24 16:09
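One way to make an interval inf-safe is to pick the map per dimension from which bounds are finite: a scaled sigmoid when both are finite, lower + exp(x) when only the lower bound is finite, upper - exp(x) when only the upper bound is finite, and the identity otherwise. This is analogous to how NumPyro treats interval, greater_than, less_than, and real separately. The following is a minimal NumPy sketch under that assumption (jax.numpy provides the same calls); inf_safe_forward and inf_safe_feasible are hypothetical helpers, not NumPyro API, and the log-det Jacobian is omitted.

```python
import numpy as np


def _finite_parts(lower, upper):
    # Replace infinite bounds with 0 so no branch ever computes inf - inf (= nan);
    # np.select evaluates every branch, so the masked values must stay finite.
    lower, upper = np.broadcast_arrays(np.asarray(lower, float), np.asarray(upper, float))
    lo_fin, hi_fin = np.isfinite(lower), np.isfinite(upper)
    lo = np.where(lo_fin, lower, 0.0)
    hi = np.where(hi_fin, upper, 0.0)
    return lo_fin, hi_fin, lo, hi


def inf_safe_forward(x, lower, upper):
    """Hypothetical helper: map unconstrained x elementwise into (lower, upper)."""
    x = np.asarray(x, float)
    lo_fin, hi_fin, lo, hi = _finite_parts(lower, upper)
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.select(
        [lo_fin & hi_fin, lo_fin, hi_fin],  # first matching condition wins
        [lo + (hi - lo) * sigmoid, lo + np.exp(x), hi - np.exp(x)],
        default=x,  # (-inf, inf): identity
    )


def inf_safe_feasible(lower, upper):
    """Hypothetical helper: a point strictly inside the support, per dimension."""
    lo_fin, hi_fin, lo, hi = _finite_parts(lower, upper)
    return np.select(
        [lo_fin & hi_fin, lo_fin, hi_fin],
        [(lo + hi) / 2, lo + 1.0, hi - 1.0],
        default=0.0,
    )


# The bounds from the example above.
lower = np.array([0.0, -np.inf, 5.0])
upper = np.array([np.inf, 0.0, 10.0])

naive_midpoint = (lower + upper) / 2  # inf, -inf, 7.5: not a usable initial point
feasible = inf_safe_feasible(lower, upper)  # 1.0, -1.0, 7.5
at_zero = inf_safe_forward(np.zeros(3), lower, upper)  # 1.0, -1.0, 7.5
```

Because every branch of np.select is computed, the infinite bounds are zeroed out before use; this keeps the sketch free of nan values and would also keep gradients finite if ported to jax.numpy.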

Hi @adrn, for parameters with different domains, it is better to split them into separate parameters, e.g. ivl0, ivl1, ivl2 in your case. We don't plan to support constraints with mixed domains.

fehiepsi, Oct 04 '24 18:10

Thanks for the response, that makes sense. However, what I want to do may not be supported at the moment. In my custom distribution, some pairs of parameters are not independent, so I can't easily split them out, and their supports vary: some are real, some interval, and some greater_than/less_than. This is probably a hack, but I ended up implementing a custom Constraint and Transform that handle the joint and independent parameters, and their supports, separately.

adrn, Oct 04 '24 19:10

If you want to build a custom joint distribution, then cat constraints and CatTransform might be helpful: https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.CatTransform. I guess we can change the title and turn this into a feature request?

fehiepsi, Oct 04 '24 21:10
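In PyTorch, CatTransform(tseq, dim, lengths) applies each transform in tseq to a consecutive slice of the input along dim and concatenates the results; constraints.cat does the same for constraint checks. Since this thread asks for the equivalent in NumPyro, here is a minimal NumPy sketch of the idea, restricted to the last axis. cat_forward, cat_check, and the block functions are hypothetical names, and log-det Jacobians are omitted. The first block is a joint 2-D "ordered pair", which shows how dependent parameters can stay together while independent ones get their own support.

```python
import numpy as np


def ordered_pair(x):
    # Joint block: (a, b) in R^2 -> (a, a + exp(b)), so out[..., 1] > out[..., 0].
    return np.stack([x[..., 0], x[..., 0] + np.exp(x[..., 1])], axis=-1)


def positive(x):
    return np.exp(x)


def interval(lower, upper):
    def forward(x):
        return lower + (upper - lower) / (1.0 + np.exp(-x))
    return forward


def cat_forward(transforms, lengths):
    """Hypothetical CatTransform-style map: transforms[i] acts on block i of the last axis."""
    split_points = np.cumsum(lengths)[:-1]

    def forward(x):
        x = np.asarray(x, dtype=float)
        if x.shape[-1] != sum(lengths):
            raise ValueError("last axis must have length sum(lengths)")
        blocks = np.split(x, split_points, axis=-1)
        return np.concatenate([t(b) for t, b in zip(transforms, blocks)], axis=-1)

    return forward


def cat_check(checks, lengths):
    """Hypothetical constraints.cat-style check: elementwise booleans per block."""
    split_points = np.cumsum(lengths)[:-1]

    def check(value):
        blocks = np.split(np.asarray(value, dtype=float), split_points, axis=-1)
        return np.concatenate([c(b) for c, b in zip(checks, blocks)], axis=-1)

    return check


def ordered_ok(y):
    ok = y[..., 1] > y[..., 0]
    return np.stack([ok, ok], axis=-1)  # same verdict for both entries of the pair


lengths = [2, 1, 1]  # joint pair, one positive scalar, one scalar in (5, 10)
to_support = cat_forward([ordered_pair, positive, interval(5.0, 10.0)], lengths)
check = cat_check(
    [ordered_ok, lambda y: y > 0, lambda y: (y > 5.0) & (y < 10.0)], lengths
)

y = to_support(np.zeros(4))  # 0.0, 1.0, 1.0, 7.5
```

Batched inputs work as long as the concatenated axis is last, e.g. to_support applied to an array of shape (n, 4) returns shape (n, 4).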

We did something similar with ConstraintCollection here: https://github.com/compmem/cognax/blob/main/cognax/joint/gaussian_copula.py. It may be a helpful resource.

amifalk, Nov 03 '24 13:11