numpyro
Support constraints.cat and CatTransform
Hello!
I have a custom multi-dimensional distribution whose support may be truncated along some dimensions. In terms of constraints, each dimension is either real, greater_than, less_than, or interval. I naively implemented the support as, e.g.:
```python
ivl = constraints.interval([0., -jnp.inf, 5.], [jnp.inf, 0., 10.])
```
Right now, this is not really supported by the numpyro.distributions.constraints._Interval class, because of how feasible_like() works and how the scale is computed in the transform to unconstrained space: with an infinite bound, the midpoint (lower_bound + upper_bound) / 2 and the scale upper_bound - lower_bound are no longer finite.

Would you be open to making these inf-safe? For now, I have implemented a custom subclass, InfSafeInterval(constraints._Interval), to support this, but I wanted to check in with you first. Thanks!
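A minimal sketch of what inf-safe behaviour could look like, written in plain JAX rather than as a numpyro Constraint/Transform subclass. The function names (inf_safe_feasible_like, to_constrained, to_unconstrained) are hypothetical, and the per-dimension bijection (shifted and scaled sigmoid for interval, lower + exp(x) for greater_than, upper - exp(x) for less_than, identity for real) is an assumption chosen to mirror the transforms used for those constraints individually. The idea is to replace infinite bounds with 0 before any arithmetic and pick the right branch per dimension with jnp.where, so inf - inf never appears in the values that are kept.

```python
import jax
import jax.numpy as jnp


def _split_bounds(lower, upper):
    """Return finiteness masks and bounds with infinities replaced by 0.

    Zeroing the infinite entries keeps every bound-only expression finite,
    so hi - lo never evaluates inf - inf.
    """
    lower, upper = jnp.asarray(lower), jnp.asarray(upper)
    lo_fin, hi_fin = jnp.isfinite(lower), jnp.isfinite(upper)
    lo = jnp.where(lo_fin, lower, 0.0)
    hi = jnp.where(hi_fin, upper, 0.0)
    return lo_fin, hi_fin, lo, hi


def inf_safe_feasible_like(lower, upper, prototype):
    """Hypothetical helper: a point strictly inside each (possibly infinite) interval."""
    lo_fin, hi_fin, lo, hi = _split_bounds(lower, upper)
    point = jnp.where(
        lo_fin & hi_fin,
        (lo + hi) / 2,  # midpoint when both bounds are finite
        jnp.where(lo_fin, lo + 1.0, jnp.where(hi_fin, hi - 1.0, 0.0)),
    )
    return jnp.broadcast_to(point, jnp.shape(prototype))


def to_constrained(x, lower, upper):
    """Map unconstrained x into the support; return (y, log|dy/dx|) elementwise."""
    lo_fin, hi_fin, lo, hi = _split_bounds(lower, upper)
    both = lo_fin & hi_fin
    scale = jnp.where(both, hi - lo, 1.0)  # finite everywhere
    y = jnp.where(
        both,
        lo + scale * jax.nn.sigmoid(x),  # interval
        jnp.where(
            lo_fin,
            lo + jnp.exp(x),  # greater_than
            jnp.where(hi_fin, hi - jnp.exp(x), x),  # less_than, else real
        ),
    )
    log_det = jnp.where(
        both,
        # log(scale * s * (1 - s)) with s = sigmoid(x)
        jnp.log(scale) - jax.nn.softplus(-x) - jax.nn.softplus(x),
        jnp.where(lo_fin | hi_fin, x, 0.0),  # |d/dx exp(x)| = exp(x)
    )
    return y, log_det


def to_unconstrained(y, lower, upper):
    """Inverse of to_constrained, elementwise."""
    lo_fin, hi_fin, lo, hi = _split_bounds(lower, upper)
    both = lo_fin & hi_fin
    scale = jnp.where(both, hi - lo, 1.0)
    u = (y - lo) / scale
    return jnp.where(
        both,
        jnp.log(u) - jnp.log1p(-u),  # logit(u)
        jnp.where(
            lo_fin,
            jnp.log(y - lo),  # greater_than
            jnp.where(hi_fin, jnp.log(hi - y), y),  # less_than, else real
        ),
    )
```

With the bounds from the example above, inf_safe_feasible_like returns lower + 1 for the greater_than dimension, upper - 1 for the less_than dimension and the midpoint for the interval dimension. Only the forward values are inf-safe here: discarded branches can still evaluate exp or log at extreme or non-positive arguments, and while jnp.where drops those values, gradients through such branches can still become nan, so a production version would also sanitise each branch's inputs (the usual "double where" trick).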