numpyro
Allow reprocessing messages, and add stop-gradient and compute-log-prob handlers
This PR introduces a new field, reprocess, for sample sites. If it is present, we apply reprocess_msg after process_msg and default_process_msg have run. I added the following two handlers to illustrate its usefulness.
detach
Unlike the proposal in https://github.com/pyro-ppl/pyro/issues/2854, here we only consider reparameterized samplers. This handler is useful for quickly exploring the "sticking the landing" strategy: simply use SVI(model, detach(guide, fields=("fn",)), optim, loss).
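A minimal sketch of the "sticking the landing" idea in plain JAX, assuming a one-dimensional Normal guide and a standard-normal target standing in for the model's log joint. The sample stays reparameterized, while the guide's parameters pass through jax.lax.stop_gradient when the sample is scored, which is roughly the effect detaching the guide's "fn" would have at each site. The names log_p and elbo_sample are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_p(z):
    # Target density: a standard normal standing in for the model's log joint.
    return norm.logpdf(z, 0.0, 1.0)


def elbo_sample(params, eps, stick_the_landing):
    loc, log_scale = params
    scale = jnp.exp(log_scale)
    z = loc + scale * eps  # reparameterized sample: gradients flow through loc and scale
    if stick_the_landing:
        # Detach the guide's distribution parameters when scoring z, so only
        # the pathwise gradient through z remains.
        loc, scale = jax.lax.stop_gradient(loc), jax.lax.stop_gradient(scale)
    return log_p(z) - norm.logpdf(z, loc, scale)


params = (jnp.array(0.0), jnp.array(0.0))  # guide N(0, 1) matches the target exactly
eps = jnp.array(1.7)
full_grad = jax.grad(elbo_sample)(params, eps, False)
stl_grad = jax.grad(elbo_sample)(params, eps, True)
```

Even though the guide is exact here, the ordinary single-sample gradient is nonzero (it is only zero in expectation), while the STL gradient is zero for every eps. That vanishing variance at the optimum is what makes the strategy attractive.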
compute_log_prob
I have wanted this handler for a while, but it seemed tricky to implement in the past. It is an example where the new reprocess pattern is useful. In Pyro, we can simply call trace.compute_log_prob, but a NumPyro trace is just a plain dictionary, so that method is not available. As a result, we usually resort to utilities like log_density or log_likelihood to compute the log joint or the log probability at particular sites. With this handler, we can simply call
trace(compute_log_prob(model, site_filter=lambda name, msg: name == "obs")).get_trace()["obs"]
to get the log-likelihood. I believe this handler could simplify much of the internal code that deals with log probabilities.
Also fixes a typo: reparametrized
It seems we can also achieve the same effect by using lazy values: https://pyro.ai/examples/effect_handlers.html
Closed because this is not needed anymore.