
Likelihood-based inference with importance sampling

michaeldeistler opened this issue on Dec 18, 2023

It would be great to have a tutorial that uses the ImportanceSamplingPosterior to perform asymptotically correct inference when the likelihood is available.
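A rough sketch of what such a tutorial could look like (a 2D Gaussian simulator with a BoxUniform prior, so the exact posterior is known in closed form and plain NPE, importance-corrected, and ground-truth samples can be compared):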

from torch import ones, eye
import torch
from torch.distributions import MultivariateNormal

from sbi.inference import SNPE, ImportanceSamplingPosterior
from sbi.utils import BoxUniform
from sbi.inference.potentials.base_potential import BasePotential
from sbi.analysis import pairplot, marginal_plot


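# Analytically tractable simulator: x = theta + standard normal noise.
# log_prob returns the unnormalized posterior log-density (Gaussian likelihood
# plus prior); it uses the global `prior` defined below.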
class Simulator:
    def __init__(self):
        pass

    def log_prob(self, theta, x):
        return MultivariateNormal(theta, eye(2)).log_prob(x) + prior.log_prob(theta)

    def sample(self, theta):
        return theta + torch.randn(theta.shape)


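# Potential wrapping the analytic (unnormalized) posterior log-density, evaluated at the observation x_o.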
class Potential(BasePotential):
    allow_iid_x = False

    def __init__(self, prior, x_o, **kwargs):
        super().__init__(prior, x_o, **kwargs)

    def __call__(self, theta, **kwargs):
        return sim.log_prob(theta, self.x_o)


prior = BoxUniform(-5 * ones((2,)), 5 * ones((2,)))
sim = Simulator()

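# Simulate a training dataset of (theta, x) pairs from the prior.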
_ = torch.manual_seed(3)
theta = prior.sample((50,))
x = sim.sample(theta)

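# Train NPE on the simulated data and build the (uncorrected) neural posterior.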
_ = torch.manual_seed(4)
inference = SNPE(prior=prior)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

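# Draw ground-truth parameters and simulate the corresponding observations.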
_ = torch.manual_seed(2)
theta_gt = prior.sample((5,))
observations = sim.sample(theta_gt)
print("observations.shape", observations.shape)


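# For each observation: draw plain NPE samples, correct them with sampling importance
# resampling (SIR) against the analytic potential, and draw reference samples from the
# true posterior (a Gaussian centered at the observation, truncated to the prior support).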
oversampling_factor = 128  # higher values are slower but more accurate
n_samples = 5000

non_corrected_samples_for_all_observations = []
corrected_samples_for_all_observations = []
true_samples = []
for obs in observations:
    non_corrected_samples_for_all_observations.append(
        posterior.set_default_x(obs).sample((n_samples,))
    )
    corrected_posterior = ImportanceSamplingPosterior(
        potential_fn=Potential(prior=None, x_o=obs),
        proposal=posterior.set_default_x(obs),
        method="sir",
    )
    corrected_samples = corrected_posterior.sample((n_samples,), oversampling_factor=oversampling_factor)
    corrected_samples_for_all_observations.append(corrected_samples)

    gt_samples = MultivariateNormal(obs, eye(2)).sample((n_samples * 5,))
    gt_samples = gt_samples[prior.support.check(gt_samples)][:n_samples]
    true_samples.append(gt_samples)


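# Plot marginals: compare NPE, importance-corrected, and ground-truth samples for each observation.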
for i in range(len(observations)):
    fig, ax = marginal_plot(
        [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], 
        limits=[[-5, 5], [-5, 5]], 
        points=theta_gt[i], 
        figsize=(5, 1.5),
        diag="kde",  # smooth histogram
    )
    ax[0][1].legend(["NPE", "Corrected", "Ground truth"], loc="upper right", bbox_to_anchor=[1.8, 1.0, 0.0, 0.0])
