pyro
[FR] Predictive with deterministic site in the guide
Hi,
I'm working on a project where we would like to access the output of an NN in the guide when using `Predictive`. We've implemented this with a deterministic site in the guide. The program boils down to the following:
```python
import pyro
from pyro.infer import Predictive
from pyro.distributions import Normal
import torch


def model():
    pyro.deterministic('m_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))


def guide():
    pyro.deterministic('g_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))


Predictive(
    model=model,
    guide=guide,
    return_sites=('m_deter', 'g_deter', 'x'),
    num_samples=1)()  # Includes m_deter but not g_deter
```
We would like both `m_deter` and `g_deter` to be included. It looks like `Predictive` currently only considers model sites as return sites. Would it be possible to extend it so that deterministic sites from the guide can be included?
This looks reasonable to me. Since deterministic guide sites are usually ignored (e.g. in AutoGuides), I think we may want to gate this new behavior behind an argument like `return_deterministic_guide_sites: bool` or similar.
A guard makes sense. I'll give an implementation a shot.
@OlaRonning Can I help with this issue?
@SarthakNikhal Absolutely. Feel free to look at my WIP PR. I wrote the unit test fairly quickly, so you can probably write a more suitable one.
@OlaRonning Okay. What can I do better? Also, what other unit tests can you think of?
I would make the test cover four cases:
- `return_deterministic` is true and there are no `return_sites`: the returned samples should include all deterministic sites in the guide.
- `return_deterministic` is true and `return_sites` includes one of two deterministic sites in the guide: the returned samples should include only the deterministic guide site listed in `return_sites`.
- `return_deterministic` is true and there are no deterministic sites in the guide: the returned samples should be the same as when `return_deterministic` is false.
- `return_deterministic` is false and there is a deterministic site in the guide: the returned samples should not include the deterministic site from the guide.
You'd probably want to check both that sites are included in the returned samples and that their values are as expected. I believe you can work directly on the aleatory_science branch.