numpyro
[FR] MaskTransform
I'm currently working on a project where I'm partially conditioning a large set of latent variables. Here's a toy example:
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist

dat = dist.Normal(10, 1).sample(random.PRNGKey(0), (100,))
mask = jnp.concatenate([jnp.ones(50, jnp.bool_), jnp.zeros(50, jnp.bool_)])

def model():
    with numpyro.plate('n_obs', len(dat)):
        a = numpyro.sample('a', dist.Normal(0, 3), obs=jnp.full(len(dat), 9), obs_mask=mask)
        b = numpyro.sample('b', dist.Normal(0, 3))
        numpyro.sample('dat', dist.Normal(a + b, 1), obs=dat)
As discussed in #2772, this quickly becomes inefficient for HMC methods, because it samples from the complete latent vector under the hood at every step, even though half the values do not contribute to the likelihood.
Can we support a minimal version of a MaskTransform that converts a partially masked site into a completely observed site and a completely unobserved site for more efficient inference? That is, convert this:
with numpyro.plate('n_obs', len(dat)):
    a = numpyro.sample('a', dist.Normal(0, 3), obs=jnp.full(len(dat), 9), obs_mask=mask)
into this:
with numpyro.plate('_a_observed', int(jnp.sum(mask))):
    a_observed = numpyro.sample('a_observed', dist.Normal(0, 3), obs=jnp.full(int(jnp.sum(mask)), 9))
with numpyro.plate('_a_unobserved', len(mask) - int(jnp.sum(mask))):
    a_unobserved = numpyro.sample('a_unobserved', dist.Normal(0, 3))
If we force obs_mask to be static (i.e., not determined by upstream latent variables), for example, then I think this is still broadly applicable and fairly simple to implement.
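A short JAX-only sketch of why the static restriction matters. The result of boolean-mask indexing has a shape that depends on the mask's values, so under `jax.jit` it only works when the mask is a concrete (here NumPy) array. The function names below are illustrative and not part of any proposed API.

```python
import numpy as np
import jax
import jax.numpy as jnp

values = jnp.arange(6.0)
static_mask = np.array([True, True, False, False, True, False])

@jax.jit
def select_static(x):
    # The NumPy mask is a compile-time constant, so the output shape is known.
    return x[static_mask]

@jax.jit
def select_traced(x, m):
    # Here the mask is a traced value, so the output shape is unknown.
    return x[m]

selected = select_static(values)

try:
    select_traced(values, jnp.asarray(static_mask))
    traced_ok = True
except jax.errors.NonConcreteBooleanIndexError:
    traced_ok = False

print(selected, traced_ok)
```

The same constraint would apply to the sizes of the two plates: they have to be known before the model is traced.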
An alternative would be to allow upstream latents to determine 'obs_mask' and force recompilation each time that happens (since the completely observed plate and completely unobserved plate would change size).
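As a rough illustration of what that alternative would cost, the sketch below counts how often a jitted function is retraced when its input size changes. It is plain JAX rather than NumPyro, and `trace_log` and `total` are hypothetical names; the point is only that each new shape triggers a fresh trace and compile.

```python
import jax
import jax.numpy as jnp

trace_log = []  # shapes seen at trace time (one entry per compilation)

@jax.jit
def total(x):
    # This Python side effect runs only while JAX traces the function.
    trace_log.append(x.shape)
    return x.sum()

total(jnp.ones(3))  # first call: trace and compile for shape (3,)
total(jnp.ones(3))  # same shape: cached, no retrace
total(jnp.ones(4))  # new shape: trace and compile again

print(trace_log)
```

With a mask driven by upstream latents, the sizes of the observed and unobserved plates could change from step to step, so recompilation could happen often.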
Hi @amifalk, I feel that your verbose version is a good one. IMO it removes the unnecessary abstraction, which could confuse users.
I think we can close this issue because the verbose code seems to be easier to understand. For modeling, we encourage users to make it more interpretable.