FMPE not working even on simple tasks?
Hi,
Thanks for your work on this nice package!
I have been trying the FMPE method implemented as sbi.inference.FMPE on a few SBI tasks. With the default parameter settings, the performance is far worse than the results reported in the paper (Flow Matching for Scalable Simulation-Based Inference).
For example, on the two-moons task from sbibm, FMPE with either the mlp or the resnet network, trained on 10,000 simulations, often only reaches a C2ST accuracy of about 0.9, or even worse.
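(For reference, C2ST is the classifier two-sample test: a classifier is trained to distinguish posterior samples from reference samples, and its held-out accuracy is reported, so 0.5 means the two sample sets are indistinguishable and values near 1.0 mean the posterior is badly off. A rough sketch of the idea, not sbibm's exact implementation:)

    import numpy as np
    from sklearn.model_selection import cross_val_score
    from sklearn.neural_network import MLPClassifier

    def c2st_sketch(samples_p: np.ndarray, samples_q: np.ndarray, folds: int = 5) -> float:
        # Label samples from the two distributions 0 and 1.
        X = np.concatenate([samples_p, samples_q], axis=0)
        y = np.concatenate([np.zeros(len(samples_p)), np.ones(len(samples_q))])
        # Cross-validated classifier accuracy is the C2ST score.
        clf = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=1000)
        scores = cross_val_score(clf, X, y, cv=folds, scoring="accuracy")
        return float(scores.mean())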
These are the posterior samples from the trained FMPE, conditioned on the first observation from sbibm:
Here is the code snippet I am using. It includes some hydra and wandb configuration, but otherwise it uses the stock FMPE implementation in sbi without changing any parameters.
import logging
import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from sbi.analysis import pairplot
from sbi.inference import FMPE

import sbibm
import wandb
from sbibm.metrics import c2st


@hydra.main(version_base=None, config_path="./configs", config_name="train_fmpe_sbibm")
def train(cfg: DictConfig):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Device: {device}")

    enable_wandb = cfg.wandb.enabled
    if enable_wandb:
        wandb.init(
            config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
            project=cfg.wandb.project,
            tags=cfg.wandb.tags,
            name=cfg.wandb.name,
            reinit=True,
        )

    # Set up the sbibm task: prior, simulator, observation, reference posterior.
    task = sbibm.get_task(cfg.data.name)
    prior_gen = task.get_prior()
    prior_dist = task.get_prior_dist()
    simulator = task.get_simulator()
    observation = task.get_observation(num_observation=1)
    reference_samples = task.get_reference_posterior_samples(num_observation=1)

    # Simulate training and test data from the prior.
    theta_train = prior_gen(num_samples=cfg.data.num_training_samples)
    x_train = simulator(theta_train)
    theta_test = prior_gen(num_samples=cfg.data.num_test_samples)
    x_test = simulator(theta_test)

    # Train FMPE with default settings.
    model = FMPE(
        density_estimator="resnet",
        prior=prior_dist,
        device=device,
    )
    model.append_simulations(theta=theta_train.to(device), x=x_train.to(device)).train(
        force_first_round_loss=True
    )
    posterior = model.build_posterior()

    # Unnormalized posterior log-probs of held-out (theta, x) pairs.
    test_log_probs = posterior.log_prob_batched(
        theta_test.unsqueeze(0).to(device), x_test.to(device), norm_posterior=False
    )

    # Sample the posterior at the first observation; move to CPU for the
    # sklearn-based c2st metric and for plotting.
    samples = posterior.sample(sample_shape=(len(reference_samples),), x=observation).cpu()
    c2st_accuracy = c2st(samples, reference_samples)

    if enable_wandb:
        wandb.run.summary["mean test log_prob"] = torch.mean(test_log_probs)
        wandb.run.summary["c2st"] = c2st_accuracy
    logging.info(f"mean test log_prob: {torch.mean(test_log_probs)}")
    logging.info(f"c2st: {c2st_accuracy}")

    fig, _ = pairplot(samples)
    os.makedirs("./results/figures", exist_ok=True)
    fig.suptitle(f"c2st={c2st_accuracy.item()}", fontsize=16)
    fig.savefig(
        "./results/figures/fmpe_sbibm_" + cfg.data.name + "_posterior_samples.png"
    )


if __name__ == "__main__":
    train()
Thanks for reporting this!
Indeed, for the mini-sbibm we are working on (see #1335), we observe similarly poor performance on two-moons (though with fewer simulations).
We will look into the differences between the implementation here and the one in https://github.com/dingo-gw/flow-matching-posterior-estimation.
Thanks for your response, and good to know about your ongoing work!
For more information, I also ran the codebase at https://github.com/dingo-gw/flow-matching-posterior-estimation on two-moons. With 10,000 simulations its performance is also poor (around 0.8 C2ST accuracy), much worse than the results in their paper, but with 100,000 simulations it performs very well (about 0.55 C2ST accuracy).
There was a potentially relevant bug fix: #1492.
See PR #1497 for potential improvements, e.g., this comment about the architecture: https://github.com/sbi-dev/sbi/pull/1497#issuecomment-2751932097
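One direction discussed there is using a larger vector-field network instead of the defaults. A hedged sketch of how that could look via sbi's flowmatching_nn factory, reusing prior_dist and device from the script above; the keyword arguments shown (hidden_features, num_blocks) are assumptions, so check the factory signature in your installed sbi version:

    from sbi.inference import FMPE
    from sbi.neural_nets import flowmatching_nn

    # Build a larger vector-field network than the default. The kwargs below
    # are assumptions about the factory signature -- verify against your
    # installed sbi version.
    net_builder = flowmatching_nn(
        model="resnet",       # same family as in the script; "mlp" also works
        hidden_features=256,  # assumed kwarg: wider hidden layers
        num_blocks=5,         # assumed kwarg: more residual blocks
    )

    model = FMPE(density_estimator=net_builder, prior=prior_dist, device=device)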
#1544 adds bug fixes and improvements to these methods and fixes many performance issues.