
Guide samples for AutoDelta don't contain deterministic variables

fbartolic opened this issue.

I noticed that samples from the AutoDelta guide don't contain deterministic transforms of random variables in the model. Here's a minimal example which reproduces the issue:

```python
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.elbo import Trace_ELBO
from jax import random

def model():
    x = numpyro.sample('x', dist.Normal(0, 1).expand([10]))
    numpyro.deterministic('x2', x**2)

guide = AutoDelta(model)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
# svi.run initializes the SVI state itself, so a separate svi.init call is not needed
svi_result = svi.run(random.PRNGKey(1), 1000)
guide_samples = guide.sample_posterior(random.PRNGKey(2), svi_result.params)
# At the time of this report this raises KeyError: only the latent site 'x' is returned
print(guide_samples['x2'])
```

Is this related to #946?

fbartolic, Mar 10 '21 22:03

Currently, AutoDelta.sample_posterior does not return deterministic sites. To collect those sites, we need to run Predictive on the model given those posterior samples.

fehiepsi, Mar 10 '21 22:03

Since this is marked as a good first issue, I could give this a try as a first contribution.

To run Predictive on the model, I would also pass the model's *args and **kwargs to AutoDelta.sample_posterior. My first attempt is something like this. I would really appreciate it if someone could take a brief look to see whether this makes sense before I create a pull request!

nikmich1, Apr 17 '23 13:04

Hi @nikmich1, the solution looks good. To account for the non-trivial case of sample_shape, you can add batch_ndims=len(sample_shape) to the Predictive utility.

fehiepsi, Apr 18 '23 18:04

@fehiepsi Thank you very much for the feedback! I added the sample_shape as suggested.

While running the unit tests, I noticed that sample_posterior should pick up only deterministic variables, not observed ones, to be consistent with, e.g., AutoNormal. Does it make sense to return only variables with AutoDelta.prototype_trace[variable]["type"] == "deterministic" (something like this)?

nikmich1, Apr 27 '23 08:04

LGTM. You can obtain the deterministic sites in advance so that we can

  • skip running Predictive if there are no deterministic sites
  • specify the return_sites keyword in the Predictive constructor.

fehiepsi, May 01 '23 19:05