pyro
Add Helper Function for evaluating log_likelihood
Issue Description
I'm currently training some Bayesian neural networks using HMC. Would it be useful to include the calculation of log likelihood as a helper function? Probably something like this, using conditioning and a trace:
```python
import torch
import pyro

def log_likelihood(model, posterior_samples, x, y):
    log_likelihoods = []
    sample_count = next(iter(posterior_samples.values())).shape[0]
    for i in range(sample_count):
        # Set the parameters of the model to the values in the i-th sample
        conditioned_model = pyro.condition(
            model, data={k: torch.tensor(v[i]) for k, v in posterior_samples.items()}
        )
        # Compute the log likelihood of the data given these parameters
        trace = pyro.poutine.trace(conditioned_model).get_trace(
            torch.from_numpy(x).to(torch.float32), torch.from_numpy(y).to(torch.float32)
        )
        log_likelihoods.append(trace.log_prob_sum())
    # Average the log likelihoods over all samples
    return torch.stack(log_likelihoods).mean()
```
I believe the above code computes the posterior predictive log density, which includes both prior and likelihood. In the past, when I've computed log-likelihood, I've manually masked out the prior sites. I'm unsure whether it's practical to automatically mask out prior sites in a way that is correct for reparametrization and other auxiliary variables.
Maybe a first step could be adding a log-likelihood computation to a couple existing tutorials, then seeing if there's a general implementation (that is e.g. batchable)?
@fritzo thanks for pointing out my mistake, much appreciated! I was just wondering how you masked out your prior sites systematically (I'm very new to Pyro)? You're right, it might be worth writing up a couple of tutorials for the log-likelihood computation; do you have any recommendations of where to start?
> how did you mask out your prior sites systematically?
I've enclosed the top of a hierarchical model in a boolean `poutine.mask(mask=___)`, e.g.
```python
import pyro
import pyro.poutine as poutine
from pyro.distributions import Normal, LogNormal

def example_model(data, include_prior: bool = True):
    # Sample top level variables from the prior.
    with poutine.mask(mask=include_prior):
        loc = pyro.sample("loc", Normal(0, 1))
        scale = pyro.sample("scale", LogNormal(0, 1))
    # Observe data.
    pyro.sample("data", Normal(loc, scale), obs=data)
```
> do you have any recommendations of where to start?
Gosh, there are over 50 tutorials on https://pyro.ai/examples. You might pick a domain you're interested in and add a section at the end. Then "likelihood" should still show up in search results.
@fritzo Thanks for this. Great, I can add them to the tutorial; that might be useful!