numpyro
Sample from distribution without storing
I am currently working on a project where we embed a VAE decoder inside a model. Accordingly, we need to sample z's from a multivariate normal distribution, but we are not interested in the posterior of the z's. Here is an example model:
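import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(y=None):
    var = numpyro.sample("variance", dist.HalfNormal())
    ls = numpyro.sample("lengthscale", dist.HalfNormal())
    # identity covariance: an all-zeros covariance is not a valid MultivariateNormal
    z = numpyro.sample("z", dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)))  # <- want to sample but not store
    # vae (the trained decoder) and mask are defined elsewhere
    y_hat = numpyro.deterministic("y_hat", vae.decode(jnp.concatenate([z, jnp.stack([ls, var])])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
    numpyro.sample("obs", dist.Normal(y_hat[mask], sigma), obs=y)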
Currently, we are running inference on 50x50 grids with a z dimension of 2500 (one z per point in the grid), which means a standard model saves 2500 z's per step. We never use these z's and would like to avoid storing them to save memory and computation. We would greatly appreciate any advice!
I don't think we store the latent values. Could you elaborate?
Sorry for the late reply. I may be misunderstanding, but the problem is that we are sampling latent variables that are nuisance parameters, so we don't need estimates of their posteriors. Is using numpyro.sample still the correct construct for latent nuisance parameters, or is there a lighter-weight sampling procedure, e.g. a pure JAX method, that might be more appropriate?
Are you using MCMC? There is collect_fields to filter out variables that are not required. If you are using SVI, then we don't store latent variables during training.
I have the exact same issue!
@fehiepsi Could you elaborate on the use of collect_fields? I can't find relevant entries in the docs.
See e.g. this comment https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/4?u=fehiepsi
I wasn't able to discern from that comment how to use collect_fields. It isn't an argument to NUTS(...), MCMC(...), or mcmc.run(..., collect_fields=...). Where / how do you add collect_fields, and is it just a list of variable names you want to keep? I'm using numpyro==0.13.2. Thank you!
If my understanding is correct, the only way is to run the MCMC step by step and manually keep track of the parameters of interest.
Sorry, my brain was not working when I sent the previous comment. The argument name is extra_fields, not collect_fields. There is a property named default_fields which determines which variables are stored. I think we can enable an API to allow doing
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.sampler.default_fields = ("z.foo", "z.bar")  # proposed: default_fields currently has no setter
following the changes in the forum comment (linked in my last comment).
I'm still a bit lost on how I might do this for nested arrays. For instance, I am running the following model on an "image" of satellite data and I've got 60 subgrids of size 50x50. I sample a random vector z of 512 values for each subgrid for each sample. So, if I'm doing 1000 samples, that is 60 * 512 * 1000 values stored. However, I don't care about the posteriors of these values; they are simply used to seed a generative model that I have inserted as a deterministic transformation (simulator.decode in the model). What is the best way to ignore the posteriors of the 60 * 512 z values?
# num_subgrids, z_dim, simulator, params and non_nan_idx are defined elsewhere
def satellite_model(T=None):
    sigma_T = numpyro.sample("sigma_T", dist.HalfNormal(10))
    for b in range(num_subgrids):
        z = numpyro.sample(f"z[{b}]", dist.Normal(0, 1).expand([z_dim]))
        ls = numpyro.sample(f"ls[{b}]", dist.Beta(3, 6))
        var = numpyro.sample(f"var[{b}]", dist.LogNormal(0, 1))
        c = jnp.hstack([ls, var])
        mu = simulator.apply(
            {"params": params}, z, c, method=simulator.decode
        ).squeeze()
        numpyro.sample(
            f"T[{b}]",
            dist.Normal(mu[non_nan_idx[b]], sigma_T),
            obs=T[b][non_nan_idx[b]],
        )
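For scale, the count above works out as follows, assuming float32 storage (JAX's default unless x64 is enabled) and a single chain; more chains multiply it.

```python
# Back-of-envelope storage for the z draws in the satellite model.
num_subgrids, z_dim, num_samples = 60, 512, 1000
bytes_per_value = 4  # assumes float32

num_values = num_subgrids * z_dim * num_samples
num_bytes = num_values * bytes_per_value
print(num_values, num_bytes / 2**20)  # 30720000 117.1875  (values, MiB per chain)
```

So roughly 117 MiB per chain goes to z draws alone, before counting the other sites.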
If you have a model with density p(x, y) and y is a "nuisance" variable, in the sense that you don't care about its posterior but you still want to integrate out the uncertainty associated with its unknown value, then you still need to do inference over y. Different values of y pick out different slices of p(x, y), which lead to different conditional posteriors over x, so there's no way around doing inference on y.

Of course, to save memory you needn't actually save all the y samples.

There's also another possibility, in which you're not actually trying to do "proper inference": maybe y is instead fixed once at the beginning, or sampled from a fixed distribution at each step of inference. But that's not doing proper inference over x in the presence of uncertainty over y.
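The argument above can be written out explicitly. Using D for the observed data (notation added here), the posterior over x averages the conditional posteriors over the posterior of y:

```latex
p(x \mid D) = \int p(x, y \mid D)\, dy = \int p(x \mid y, D)\, p(y \mid D)\, dy
```

A joint MCMC draw (x^{(s)}, y^{(s)}) from p(x, y | D) therefore still has to move y, but once x^{(s)} is recorded, y^{(s)} can be thrown away: the x^{(s)} on their own are draws from p(x | D).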
We do need to do inference over z, especially since we are using HMC and it will be computing gradients with respect to z in the latent space, but we would prefer not to save all these samples, to save memory. Is there a way to do this?
If you don't want to change the source code, then you can do:
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("y", dist.Normal(0, 1))

class CustomNUTS(NUTS):
    def postprocess_fn(self, args, kwargs):
        transform = super().postprocess_fn(args, kwargs)

        def new_transform(z):
            z = transform(z)
            z.pop("x")
            return z

        return new_transform

mcmc = MCMC(CustomNUTS(model), num_warmup=10, num_samples=20)
mcmc.run(jax.random.PRNGKey(0))
mcmc.get_samples().keys()
But it's easy to support this feature. As outlined above, we can:
- allow users to change the default field z (in your case, you can replace it by z.foo, z.bar, where foo and bar are latent variables that you want to keep)
- update the behavior of collect_fields (as in https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/4?u=fehiepsi)
Let's keep this issue open in case a contributor wants to support this feature. You can use the above CustomNUTS in the meantime.
I find I often have a pattern where my random variable is a nuisance variable, but some deterministic function of it is meaningful. In this case, the desired behavior is more so a function of the model than a function of the inference algorithm, so it's inconvenient to have to tamper with settings for every fit.
I would much prefer to have a flag in the numpyro.sample function to toggle whether or not a site is collected during MCMC.
a = numpyro.sample('a_', dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)), collect=False)  # proposed flag, not an existing argument
a = numpyro.deterministic('a', a * 2)
What do you think @fehiepsi ?
Yes, we can add a field to the "infer" keyword. But this requires us to update all MCMC kernels. I feel that supporting mcmc.sampler.default_fields = ("z.a",) is simpler. What do you think?
It looks like all the samplers create a trace on initialization, most via initialize_model. It should be easy to add a function to infer.util that takes the trace and returns the default fields. Even though we would have to update each one, I don't think it would add much complexity. We would just need to add a setter method for default_fields in the MCMCKernel superclass and add one line to each kernel.
def _init_state(self, ...):
    init_params, potential_fn, postprocess_fn, model_trace = numpyro.infer.util.initialize_model(...)
    self.default_fields = numpyro.infer.util.get_default_fields(model_trace)  # proposed helper
Is this solution ok for you? If so, I would be happy to draft up a PR.
Yup, I think the solution looks good. Users can use either infer or default_fields.