
[FR] Predictive with deterministic site in the guide

Open · OlaRonning opened this issue 10 months ago · 6 comments

Hi,

I'm working on a project where we would like to access the output of an NN in the guide when using Predictive. We've implemented it using a deterministic site in the guide. The program boils down to the following.

import pyro
from pyro.infer import Predictive
from pyro.distributions import Normal
import torch

def model():
    # Deterministic site in the model.
    pyro.deterministic('m_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))

def guide():
    # Deterministic site in the guide (stands in for the NN output we want to access).
    pyro.deterministic('g_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))


Predictive(
  model=model,
  guide=guide,
  return_sites=('m_deter', 'g_deter', 'x'),
  num_samples=1)()  # Includes m_deter but not g_deter

We would like both m_deter and g_deter to be included. It looks like Predictive currently only considers model sites when selecting return_sites. Would it be possible to extend it so that deterministic sites from the guide can be included?

OlaRonning, Apr 19 '24 11:04

This looks reasonable to me. Since deterministic guide sites are usually ignored (e.g. in AutoGuides), I think we may want to gate this new behavior behind an arg like return_deterministic_guide_sites: bool or something similar.

fritzo, Apr 20 '24 20:04

A guard makes sense. I'll give an implementation a shot.

OlaRonning, Apr 21 '24 13:04

@OlaRonning Can I help with this issue?

SarthakNikhal, Apr 23 '24 12:04

@SarthakNikhal Absolutely. Feel free to look at my WIP PR. I wrote the unit test fairly quickly; you can probably develop a more suitable one.

OlaRonning, Apr 23 '24 13:04

@OlaRonning Okay. What can I do better? Also, what other unit tests can you think of?

SarthakNikhal, Apr 23 '24 14:04

I would make the test cover four cases:

  1. return_deterministic is true and there are no return_sites. The returned samples should include all deterministic sites in the guide.
  2. return_deterministic is true and return_sites includes one of two deterministic sites in the guide. The returned samples should include only the deterministic guide site listed in return_sites.
  3. return_deterministic is true and there are no deterministic sites in the guide. The returned samples should be the same as when return_deterministic is false.
  4. return_deterministic is false and there is a deterministic site in the guide. The returned samples should not include the deterministic site from the guide.

You'd probably want to check both that sites are included in the returned samples and that their values are as expected. I believe you can work directly on the aleatory_science branch.

OlaRonning, Apr 23 '24 17:04