Guide samples for AutoDelta don't contain deterministic variables
I noticed that samples from the AutoDelta guide don't contain deterministic transforms of random variables in the model. Here's a minimal example which reproduces the issue:
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.elbo import Trace_ELBO
from jax import random


def model():
    x = numpyro.sample('x', dist.Normal(0, 1).expand([10]))
    numpyro.deterministic('x2', x**2)


guide = AutoDelta(model)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi_result = svi.run(random.PRNGKey(1), 1000)
guide_samples = guide.sample_posterior(random.PRNGKey(2), svi_result.params)
print(guide_samples['x2'])  # reported behavior: 'x2' is not in guide_samples
Is this related to #946?
Currently, AutoDelta.sample_posterior does not return deterministic sites. To collect those sites, we need to run Predictive on the model given those posterior samples.
Since this is marked as a good first issue, I could give this a try as a first contribution.
To run Predictive on the model, I would pass the model's *args and **kwargs to AutoDelta.sample_posterior as well. My first attempt is something like this. I would highly appreciate it if someone could take a brief look at whether this makes sense before I create a pull request!
Hi @nikmich1, the solution looks good. To account for the non-trivial case of sample_shape, you can add batch_ndims=len(sample_shape) to the Predictive utility.
@fehiepsi Thank you very much for the feedback! I added the sample_shape as suggested.
While running the unit tests, I noticed that sample_posterior should only pick up deterministic variables, not observed ones, to be consistent with other guides such as AutoNormal. Does it make sense to only return variables with AutoDelta.prototype_trace[variable]["type"] == "deterministic" (something like this)?
LGTM. You can obtain deterministic sites in advance so that we can
- skip running Predictive if there are no deterministic sites
- specify the return_sites keyword in the Predictive constructor.