sbi
sbi copied to clipboard
Likelihood-based inference with importance sampling
It would be great to have a tutorial that uses the `ImportanceSamplingPosterior` to perform asymptotically correct inference when the likelihood is available.
from torch import ones, eye
import torch
from torch.distributions import MultivariateNormal
from sbi.inference import SNPE, ImportanceSamplingPosterior
from sbi.utils import BoxUniform
from sbi.inference.potentials.base_potential import BasePotential
from sbi.analysis import pairplot, marginal_plot
class Simulator:
    """Toy Gaussian simulator in which x ~ N(theta, I).

    It provides both a sampler (used to build the NPE training set) and a
    closed-form log-density (used by the importance-sampling correction).
    """

    def __init__(self, prior=None):
        """Create the simulator.

        Args:
            prior: Optional distribution over theta whose ``log_prob`` is added
                in :meth:`log_prob`. If omitted, the module-level ``prior`` is
                looked up at call time. This keeps ``Simulator()`` working as
                before.
        """
        self.prior = prior

    def log_prob(self, theta, x):
        """Return log N(x; theta, I) + log p(theta).

        Despite its name, this is the *joint* (unnormalized posterior)
        log-density, not the likelihood alone, because the prior term is
        included.

        Args:
            theta: Parameters; the last dimension is the parameter dimension.
            x: Observation(s), broadcastable against ``theta``.

        Returns:
            Tensor of log-densities with ``theta``'s batch shape.
        """
        # Infer the dimension instead of hard-coding 2-D.
        dim = theta.shape[-1]
        log_likelihood = MultivariateNormal(theta, eye(dim)).log_prob(x)
        # Fall back to the module-level prior for backward compatibility.
        prior_dist = self.prior if self.prior is not None else prior
        return log_likelihood + prior_dist.log_prob(theta)

    def sample(self, theta):
        """Simulate one observation per parameter: theta plus standard-normal noise."""
        # randn_like matches theta's shape, dtype and device.
        return theta + torch.randn_like(theta)
class Potential(BasePotential):
    """Exact potential that scores parameters against a stored observation.

    Scoring delegates to the module-level ``sim.log_prob``, evaluated at the
    observation ``x_o`` held by the base class.
    """

    # NOTE(review): presumably tells sbi that this potential does not accept
    # batches of iid observations — confirm against the BasePotential docs.
    allow_iid_x: bool = False

    def __init__(self, prior, x_o, **kwargs):
        # Pure pass-through, kept so that ``x_o`` stays a required argument.
        super(Potential, self).__init__(prior, x_o, **kwargs)

    def __call__(self, theta, **kwargs):
        # Any extra keyword arguments are accepted but ignored.
        observed = self.x_o
        log_density = sim.log_prob(theta, observed)
        return log_density
# Uniform box prior over the 2-D parameter space: each coordinate lies in [-5, 5].
prior = BoxUniform(-5 * ones((2,)), 5 * ones((2,)))
sim = Simulator()
# Seed before drawing the training set so the simulations are reproducible.
_ = torch.manual_seed(3)
# Training set: 50 (theta, x) pairs drawn from the prior and the simulator.
# NOTE(review): only 50 simulations — presumably intentional, so the NPE is
# visibly imperfect and the importance-sampling correction has something to fix.
theta = prior.sample((50,))
x = sim.sample(theta)
# Re-seed separately so training randomness does not depend on the data draw
# (presumably covers network initialisation and minibatching — confirm).
_ = torch.manual_seed(4)
# Train the neural posterior estimator on the simulated pairs; the result is
# labelled "NPE" in the plots below.
inference = SNPE(prior=prior)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()
_ = torch.manual_seed(2)
# Held-out ground-truth parameters and one simulated observation for each.
theta_gt = prior.sample((5,))
observations = sim.sample(theta_gt)
print("observations.shape", observations.shape)
oversampling_factor = 128  # higher will be slower but more accurate
n_samples = 5000


def _sample_reference_posterior(obs, support, num_samples, max_rounds=100):
    """Draw exactly ``num_samples`` samples from the true posterior for ``obs``.

    The likelihood is N(x; theta, I) and the prior is uniform on ``support``,
    so the true posterior is N(theta; obs, I) truncated to ``support``. It is
    sampled by rejection, in rounds, until enough draws are accepted. A single
    round can fall short when ``obs`` lies near or outside the box.

    Args:
        obs: A single observation; its last dimension is the parameter dimension.
        support: Constraint with a ``check`` method (e.g. ``prior.support``).
        num_samples: Number of samples to return.
        max_rounds: Upper bound on rejection rounds before giving up.

    Returns:
        Tensor of shape ``(num_samples, dim)``.

    Raises:
        RuntimeError: if too few draws are accepted within ``max_rounds`` rounds.
    """
    gaussian = MultivariateNormal(obs, eye(obs.shape[-1]))
    accepted = []
    num_accepted = 0
    for _ in range(max_rounds):
        # The first round uses the same draw size (5x) as before, so seeded
        # results are unchanged whenever one round is enough.
        draws = gaussian.sample((num_samples * 5,))
        draws = draws[support.check(draws)]
        accepted.append(draws)
        num_accepted += draws.shape[0]
        if num_accepted >= num_samples:
            return torch.cat(accepted)[:num_samples]
    raise RuntimeError(
        f"only {num_accepted} of {num_samples} reference samples accepted "
        f"after {max_rounds} rejection rounds"
    )


non_corrected_samples_for_all_observations = []
corrected_samples_for_all_observations = []
true_samples = []
for obs in observations:
    # Condition the amortized NPE on this observation once and reuse it below.
    conditioned_posterior = posterior.set_default_x(obs)
    non_corrected_samples_for_all_observations.append(
        conditioned_posterior.sample((n_samples,))
    )
    # Reweight NPE proposals with the exact potential ("sir" = sampling-
    # importance-resampling). prior=None because sim.log_prob already adds
    # the prior term.
    corrected_posterior = ImportanceSamplingPosterior(
        potential_fn=Potential(prior=None, x_o=obs),
        proposal=conditioned_posterior,
        method="sir",
    )
    corrected_samples = corrected_posterior.sample(
        (n_samples,), oversampling_factor=oversampling_factor
    )
    corrected_samples_for_all_observations.append(corrected_samples)
    gt_samples = _sample_reference_posterior(obs, prior.support, n_samples)
    true_samples.append(gt_samples)
# One row of 1-D marginals per observation, comparing the NPE samples, the
# importance-corrected samples and the reference posterior samples. The true
# parameter is marked with `points`.
panels = zip(
    non_corrected_samples_for_all_observations,
    corrected_samples_for_all_observations,
    true_samples,
)
for i, (npe_draws, corrected_draws, reference_draws) in enumerate(panels):
    fig, ax = marginal_plot(
        [npe_draws, corrected_draws, reference_draws],
        limits=[[-5, 5], [-5, 5]],
        points=theta_gt[i],
        figsize=(5, 1.5),
        diag="kde",  # smooth histogram
    )
    # Legend on the second panel, anchored to the right of the axes.
    legend_labels = ["NPE", "Corrected", "Ground truth"]
    ax[0][1].legend(
        legend_labels,
        loc="upper right",
        bbox_to_anchor=[1.8, 1.0, 0.0, 0.0],
    )