
Allow reprocess messages and add stop gradient and compute log prob handlers

Open • fehiepsi opened this issue 3 years ago • 1 comment

This PR introduces a new field, reprocess, for sample sites. If it is set, then after process_msg and default_process_msg have run, we apply reprocess_msg. I added the following two handlers to illustrate its usefulness.
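To make the ordering concrete, here is a minimal pure-Python sketch of a three-stage message pass. It is not numpyro's implementation and does not reproduce the PR's code: the `_STACK` list, the toy `Messenger` class, the `sample` function, the `order` bookkeeping field, and the `clip` handler are all hypothetical names invented for illustration, and the assumption that a handler opts in by setting `msg["reprocess"]` is a guess at how the field is meant to be used.

```python
_STACK = []  # active toy handlers, outermost first


class Messenger:
    """Toy handler base class (an assumption, not numpyro's Messenger)."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        _STACK.append(self)
        try:
            return self.fn(*args, **kwargs)
        finally:
            _STACK.pop()

    def process_msg(self, msg):
        pass

    def reprocess_msg(self, msg):
        pass


def default_process_msg(msg):
    # Draw a value only if no handler has supplied one.
    if msg["value"] is None:
        msg["value"] = msg["fn"]()


def sample(name, fn):
    msg = {"name": name, "fn": fn, "value": None, "reprocess": False, "order": []}
    for handler in reversed(_STACK):      # stage 1: process_msg
        handler.process_msg(msg)
    default_process_msg(msg)              # stage 2: default processing
    msg["order"].append("default_process_msg")
    if msg["reprocess"]:                  # stage 3: only if the field is set
        for handler in reversed(_STACK):
            handler.reprocess_msg(msg)
    return msg


class clip(Messenger):
    """Hypothetical handler that clamps the drawn value in the second pass."""

    def process_msg(self, msg):
        msg["order"].append("process_msg")
        msg["reprocess"] = True           # request the extra pass

    def reprocess_msg(self, msg):
        msg["order"].append("reprocess_msg")
        msg["value"] = min(msg["value"], 1.0)


def model():
    return sample("x", lambda: 3.0)


msg = clip(model)()
plain = model()
print(msg["order"], msg["value"])
```

The point of the sketch is only that the second pass sees the value produced by default processing, which the first pass cannot.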

detach

Unlike https://github.com/pyro-ppl/pyro/issues/2854, here we only consider reparameterized samplers. This handler is useful for quickly exploring the 'stick the landing' strategy: simply do SVI(model, detach(guide, fields=("fn",))).
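As a rough illustration of what detaching the distribution parameters changes, the sketch below works through a single Gaussian guide site by hand. It is not the detach handler itself: finite differences stand in for autodiff, "detaching" is modelled by freezing the density's loc at a constant, and every name in the block (`normal_log_prob`, `reparam_sample`, `grad_loc`, `loc0`, `scale0`, `eps`) is made up for this example.

```python
import math


def normal_log_prob(z, loc, scale):
    """Log density of Normal(loc, scale) at z."""
    return (
        -0.5 * ((z - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2 * math.pi)
    )


def reparam_sample(loc, scale, eps):
    """Reparameterized draw: z = loc + scale * eps."""
    return loc + scale * eps


def grad_loc(f, loc, h=1e-5):
    """Central finite difference of f with respect to loc."""
    return (f(loc + h) - f(loc - h)) / (2 * h)


loc0, scale0, eps = 0.3, 1.7, 0.8

# Full gradient: loc enters through the sample AND the density parameters.
# The two contributions (path term and score term) cancel for loc.
full = grad_loc(
    lambda loc: normal_log_prob(reparam_sample(loc, scale0, eps), loc, scale0),
    loc0,
)

# "Detached" gradient: the density's loc is frozen at loc0, so only the
# path through the sample z remains.
detached = grad_loc(
    lambda loc: normal_log_prob(reparam_sample(loc, scale0, eps), loc0, scale0),
    loc0,
)

print(full, detached, -eps / scale0)
```

With the parameters frozen, only the pathwise term -eps / scale survives, which is the term the 'stick the landing' estimator keeps.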

compute_log_prob

I have wanted this handler for a while, but it was tricky to implement in the past. This is an example where the new reprocess pattern is useful. In Pyro, we can simply call trace.compute_log_prob, but a numpyro trace is just a plain dictionary, so that method is not available. As a result, we usually resort to a utility such as log_density or log_likelihood to compute the log joint or the log prob at particular sites. With this, we can simply call

trace(compute_log_prob(model, site_filter=lambda name, msg: name == "obs")).get_trace()["obs"]

to get the log-likelihood. I believe this handler lets us simplify much of the internal code that deals with log probabilities.
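For readers unfamiliar with the idea, the following pure-Python sketch shows the kind of result such a handler would leave in a dictionary-style trace. It does not use numpyro: the `Normal` class, the hand-built `tr` dictionary, the `add_log_prob` helper, and the `log_prob` key it writes are assumptions made for illustration; only the `site_filter(name, msg)` signature is taken from the call above.

```python
import math


class Normal:
    """Minimal stand-in for a distribution object with a log_prob method."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        return (
            -0.5 * ((x - self.loc) / self.scale) ** 2
            - math.log(self.scale)
            - 0.5 * math.log(2 * math.pi)
        )


# A hand-built trace: a plain dict mapping site names to message dicts.
tr = {
    "z": {"fn": Normal(0.0, 1.0), "value": 0.5, "is_observed": False},
    "obs": {"fn": Normal(0.5, 1.0), "value": 1.5, "is_observed": True},
}


def add_log_prob(trace, site_filter=lambda name, msg: True):
    """Attach a log_prob entry to every site accepted by site_filter."""
    for name, msg in trace.items():
        if site_filter(name, msg):
            msg["log_prob"] = msg["fn"].log_prob(msg["value"])
    return trace


add_log_prob(tr, site_filter=lambda name, msg: name == "obs")
print(tr["obs"]["log_prob"])
```

Sites rejected by the filter are left untouched, so the log-likelihood is read from the observed site alone.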

Also fixes a typo: reparametrized

fehiepsi, Feb 20 '22 02:02

It seems we can also achieve the same effect by using a lazy value: https://pyro.ai/examples/effect_handlers.html

fehiepsi, Feb 22 '22 11:02

Closed because this is not needed anymore.

fehiepsi, Jan 30 '23 14:01