sbi
sbi copied to clipboard
MCMC sampler generates `inf` values
The MCMC method `slice_np_vectorized`, with the parameters
mcmc_parameters = dict(
num_chains=50,
thin=5,
warmup_steps=5,
init_strategy="proposal",
)
is generating inf values. This is related to issue #1037 and seems to depend on the warmup_steps parameter.
To Reproduce
import matplotlib.pyplot as plt
import torch
from torch import eye, zeros
from torch.distributions import MultivariateNormal
from sbi.analysis import pairplot
from sbi.inference import SNLE, simulate_for_sbi
from sbi.simulators.linear_gaussian import (
    linear_gaussian,
)
from sbi.utils.user_input_checks import (
    process_prior,
    process_simulator,
)

# Make the run reproducible.
torch.manual_seed(1)

# --- Linear-Gaussian toy problem ------------------------------------------
# Parameters and observations share the same dimensionality.
dim = 2

# The likelihood mean is `shift + theta`.
shift = -1.0 * zeros(dim)
cov = 0.3 * eye(dim)

# Standard-normal prior over the parameters.
prior_dist = MultivariateNormal(loc=zeros(dim), covariance_matrix=eye(dim))
prior, num_parameters, returns_numpy = process_prior(prior_dist)

# Wrap the linear-Gaussian simulator so sbi can call it.
simulator = process_simulator(
    lambda theta: linear_gaussian(theta, shift, cov),
    prior,
    returns_numpy,
)

# 20 iid observations, all equal to the zero vector.
theta_o = zeros(1, dim)
xo = theta_o.repeat(20, 1)

# --- Train the SNLE density estimator -------------------------------------
inferer = SNLE(prior, show_progress_bars=True, density_estimator="mdn")
theta, x = simulate_for_sbi(simulator, prior, 10000, simulation_batch_size=1000)
inferer.append_simulations(theta, x).train(training_batch_size=1000)

# --- Sample the posterior with vectorized slice sampling ------------------
num_samples = 1000
mcmc_parameters = dict(
    num_chains=50,
    thin=5,
    warmup_steps=5,
    init_strategy="proposal",
)
mcmc_method = "slice_np_vectorized"
posterior = inferer.build_posterior(
    mcmc_method=mcmc_method,
    mcmc_parameters=mcmc_parameters,
)

# Draw posterior samples conditioned on the iid observations above.
nle_samples = posterior.sample(sample_shape=(num_samples,), x=xo)

# Count all-finite sample rows — it's 999 instead of 1000 (one `inf` row).
nle_samples.isfinite().all(-1).sum()
Thanks for creating this! Relevant to #910.