flowjax
masked autoregressive flow with mixed transformer types
I am looking into a modification of a regular masked autoregressive flow where the base distribution is an N-dimensional uniform and the first variable does not get transformed, while the rest of the variables get transformed via a rational quadratic spline. I have removed the shuffling in the masked_autoregressive_flow function by removing the _add_default_permute call, and modified _flat_params_to_transformer in the MaskedAutoregressive class to apply an Identity transformer to the first dimension in the following way:
def _flat_params_to_transformer(self, params: Array, y_dim=1):
    """Reshape to dim X params_per_dim, then vmap."""
    dim = self.shape[-1]
    transformer_params = jnp.reshape(params, (dim, -1))
    # Discard the parameters produced for the first y_dim variables,
    # which are left untransformed.
    transformer_params = transformer_params[y_dim:, :]
    transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
    return Concatenate(
        [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
    )
My understanding is that in this way the masked_autoregressive_mlp will still produce a set of spline parameters for the first variable, which then never get used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure this is the most efficient way to go about it or whether I am disregarding anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!
I think your approach works, but as you said it has a bit of extra overhead, since the masked autoregressive network will still produce a set of (unused) parameters for the identity-transformed variables. If you want to avoid that, here's another possibility.
What I have done is wrap a masked autoregressive bijection whose dimension matches the dimensionality of the transformed variables and whose cond_dim matches the number of identity variables. In each method we pass the identity-transformed variables as conditioning variables to the masked autoregressive bijection. This should be an equivalent architecture, except it avoids the unnecessary computation.
from typing import ClassVar

from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.masked_autoregressive import MaskedAutoregressive


class IdentityFirstMaskedAutoregressive(AbstractBijection):
    """Leave the first identity_dim variables untransformed, passing them as
    conditioning variables to a masked autoregressive bijection over the rest."""

    masked_autoregressive: MaskedAutoregressive
    identity_dim: int
    shape: tuple[int, ...]
    cond_shape: ClassVar[None] = None

    def __init__(self, masked_autoregressive: MaskedAutoregressive):
        self.masked_autoregressive = masked_autoregressive
        self.identity_dim = masked_autoregressive.cond_shape[0]
        self.shape = (self.identity_dim + self.masked_autoregressive.shape[0],)

    def transform(self, x, condition=None):
        y = self.masked_autoregressive.transform(
            x[self.identity_dim :], condition=x[: self.identity_dim]
        )
        return x.at[self.identity_dim :].set(y)

    def transform_and_log_det(self, x, condition=None):
        y, log_det = self.masked_autoregressive.transform_and_log_det(
            x[self.identity_dim :], condition=x[: self.identity_dim]
        )
        return x.at[self.identity_dim :].set(y), log_det

    def inverse(self, y, condition=None):
        x = self.masked_autoregressive.inverse(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x)

    def inverse_and_log_det(self, y, condition=None):
        x, log_det = self.masked_autoregressive.inverse_and_log_det(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x), log_det
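For reference, a minimal usage sketch might look something like the following (the hyperparameters here are arbitrary, and the base distribution and transformer just mirror your setup):

import jax
import jax.numpy as jnp
from flowjax.bijections import MaskedAutoregressive, RationalQuadraticSpline
from flowjax.distributions import Transformed, Uniform

N = 5  # total dimension; the first variable is left untransformed

# The inner bijection acts on N - 1 variables and is conditioned on the
# single identity-transformed variable via cond_dim=1.
inner = MaskedAutoregressive(
    key=jax.random.PRNGKey(0),
    transformer=RationalQuadraticSpline(knots=5, interval=1.0),
    dim=N - 1,
    cond_dim=1,
    nn_width=10,
    nn_depth=1,
)
bijection = IdentityFirstMaskedAutoregressive(inner)
flow = Transformed(Uniform(-jnp.ones(N), jnp.ones(N)), bijection)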
If you need to support a conditional version of this, then it should be possible with some concatenating and adjusting of shapes.
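For instance (purely a sketch, not tested), cond_shape would become a regular field, the inner bijection's cond_dim would cover both the identity variables and the external condition, and each method would concatenate the two before passing them on:

def transform_and_log_det(self, x, condition=None):
    # Sketch: concatenate the identity variables with the external condition
    # before passing them to the inner masked autoregressive bijection.
    inner_condition = jnp.concatenate([x[: self.identity_dim], condition])
    y, log_det = self.masked_autoregressive.transform_and_log_det(
        x[self.identity_dim :], condition=inner_condition
    )
    return x.at[self.identity_dim :].set(y), log_det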
In general it could be possible to add support for a mix of transformer types, but if, for example, we have a list of heterogeneous transformers, then compilation speed might become an issue, as we can no longer just rely on vmap and would have to loop. Thanks for the support and let me know if you have any questions/issues!
This is a bit late, but another option is to define individual bijections for the non-transformed variables and the remaining ones, and then stack them together into a single bijection:
import jax
import jax.numpy as jnp
from flowjax.bijections import Identity, RationalQuadraticSpline, MaskedAutoregressive, Concatenate
from flowjax.distributions import Uniform, Transformed

N = 5
base_dist = Uniform(minval=-jnp.ones(N), maxval=jnp.ones(N))

bijections = [
    Identity(shape=(1,)),
    MaskedAutoregressive(
        key=jax.random.PRNGKey(0),
        transformer=RationalQuadraticSpline(knots=5, interval=1.0),
        dim=N - 1,
        nn_width=10,
        nn_depth=1,
    ),
]

# use Concatenate as it stacks bijections along an *existing* axis
bijection = Concatenate(bijections)
flow = Transformed(base_dist, bijection)
You could wrap this in a constructor with vmap and permutations only within the N - 1 dimensions, though maybe the Concatenate has some overhead.
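For what it's worth, such a constructor might look roughly like the sketch below (illustrative only; the layer count, hyperparameters and use of Permute between layers are just one possible choice):

import jax
import jax.numpy as jnp
from flowjax.bijections import (
    Chain,
    Concatenate,
    Identity,
    MaskedAutoregressive,
    Permute,
    RationalQuadraticSpline,
)
from flowjax.distributions import Transformed, Uniform


def identity_first_flow(key, dim, flow_layers=4, nn_width=10, nn_depth=1):
    # Stack several masked autoregressive layers, permuting only the last
    # dim - 1 variables between layers so the first variable always maps
    # to itself.
    layers = []
    for layer_key in jax.random.split(key, flow_layers):
        ma_key, perm_key = jax.random.split(layer_key)
        layers.append(
            Concatenate(
                [
                    Identity(shape=(1,)),
                    MaskedAutoregressive(
                        key=ma_key,
                        transformer=RationalQuadraticSpline(knots=5, interval=1.0),
                        dim=dim - 1,
                        nn_width=nn_width,
                        nn_depth=nn_depth,
                    ),
                ]
            )
        )
        layers.append(
            Concatenate(
                [
                    Identity(shape=(1,)),
                    Permute(jax.random.permutation(perm_key, dim - 1)),
                ]
            )
        )
    base_dist = Uniform(-jnp.ones(dim), jnp.ones(dim))
    return Transformed(base_dist, Chain(layers))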
You can do that, but note that the transform of the transformed dimensions will be independent of the identity-transformed variables if you do it that way.
It could be possible to support a transformer with shape/dimension matching the shape of the total bijection (rather than only scalar bijections), in which case you could Stack/Concatenate transformers as you please. The main issue is you need a reliable way to know/specify which parameters are involved in transforming which dimensions. The only way I can think of would be to force passing a pytree of ints with structure matching the parameters, specifying which output dimension the parameters pertain to (quite cumbersome).
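Just to make the idea concrete (this is not an existing flowjax API, purely a hypothetical illustration): for an Affine transformer spanning three dimensions, the companion pytree could mirror the parameter structure, with one integer per parameter entry naming the output dimension it is involved in transforming.

import jax.numpy as jnp
from flowjax.bijections import Affine

# A transformer whose parameters span three output dimensions.
transformer = Affine(loc=jnp.zeros(3), scale=jnp.ones(3))

# Hypothetical companion pytree (not part of flowjax): structure mirrors the
# transformer's parameters, and each int says which output dimension that
# parameter pertains to.
param_to_dim = {"loc": jnp.array([0, 1, 2]), "scale": jnp.array([0, 1, 2])}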