
FMPE not working even on simple tasks?

yangyang-pro opened this issue 11 months ago

Hi,

Thanks for your work on this nice package!

I have been trying the implemented FMPE method (sbi.inference.FMPE) on a few SBI tasks. With the default parameter settings, I found that the performance is much worse than the results reported in the paper (Flow Matching for Scalable Simulation-Based Inference).

For example, on the two-moons task from sbibm, FMPE with either the mlp or the resnet backend, trained on 10,000 simulations, often only reaches a C2ST accuracy of about 0.9, or even worse.

These are the posterior samples from the trained FMPE, conditioned on the first observation from sbibm:

[figure: pairplot of FMPE posterior samples for the two-moons task, observation 1]

I am using the following code snippet. It contains some Hydra and wandb configuration, but otherwise I use the original FMPE implementation in sbi and do not change any parameters.

import logging
import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from sbi.analysis import pairplot
from sbi.inference import FMPE

import sbibm
import wandb
from sbibm.metrics import c2st


@hydra.main(version_base=None, config_path="./configs", config_name="train_fmpe_sbibm")
def train(cfg: DictConfig):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Device: {device}")

    enable_wandb = cfg.wandb.enabled
    if enable_wandb:
        wandb.init(
            config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
            project=cfg.wandb.project,
            tags=cfg.wandb.tags,
            name=cfg.wandb.name,
            reinit=True,
        )

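    # Load the sbibm task: prior, simulator, observation, and reference posterior samples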
    task = sbibm.get_task(cfg.data.name)
    prior_gen = task.get_prior()
    prior_dist = task.get_prior_dist()
    simulator = task.get_simulator()
    observation = task.get_observation(num_observation=1)
    reference_samples = task.get_reference_posterior_samples(num_observation=1)

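    # Simulate training and test sets by sampling from the prior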
    theta_train = prior_gen(num_samples=cfg.data.num_training_samples)
    x_train = simulator(theta_train)

    theta_test = prior_gen(num_samples=cfg.data.num_test_samples)
    x_test = simulator(theta_test)

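    # Train FMPE with the default settings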
    model = FMPE(
        density_estimator="resnet",
        prior=prior_dist,
        device=device,
    )
    model.append_simulations(theta=theta_train.to(device), x=x_train.to(device)).train(
        force_first_round_loss=True
    )

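    # Build the amortized posterior and compute test log-probabilities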
    posterior = model.build_posterior()

    test_log_probs = posterior.log_prob_batched(
        theta_test.unsqueeze(0).to(device), x_test.to(device), norm_posterior=False
    )

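    # Sample from the posterior at the reference observation and compute C2ST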
    samples = posterior.sample(sample_shape=(len(reference_samples),), x=observation)
    c2st_accuracy = c2st(samples, reference_samples)

    if enable_wandb:
        wandb.run.summary["mean test log_prob"] = torch.mean(test_log_probs)
        wandb.run.summary["c2st"] = c2st_accuracy

    logging.info(f"mean test log_prob: {torch.mean(test_log_probs)}")
    logging.info(f"c2st: {c2st_accuracy}")

    fig, _ = pairplot(samples)
    if not os.path.exists("./results/figures"):
        os.makedirs("./results/figures")
    fig.suptitle(f"c2st={c2st_accuracy.item()}", fontsize=16)
    fig.savefig(
        "./results/figures/fmpe_sbibm_" + cfg.data.name + "_posterior_samples.png"
    )


if __name__ == "__main__":
    train()
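
For the mlp backend mentioned above, a minimal sketch of the only change needed (everything else stays exactly as in the script above):

from sbi.inference import FMPE

# Minimal sketch: identical to the script above, except for the backend string.
model = FMPE(
    density_estimator="mlp",   # instead of "resnet"
    prior=prior_dist,          # prior_dist and device as defined in the script above
    device=device,
)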

yangyang-pro commented on Jan 21, 2025

Thanks for reporting this!

Indeed, for the mini-sbibm benchmark we are working on (see #1335), we observe similar performance on two-moons (though with fewer simulations).

We will have a look at the differences between the implementation here and the one in https://github.com/dingo-gw/flow-matching-posterior-estimation.

janfb commented on Jan 23, 2025

Thanks for your response, and good to know about your ongoing work!

For more information, I also ran the codebase from https://github.com/dingo-gw/flow-matching-posterior-estimation on two-moons. Its performance is also poor (around 0.8 C2ST accuracy) with 10,000 simulations, which is much worse than the results in their paper. However, it performs very well (about 0.55 C2ST accuracy) with 100,000 simulations.

yangyang-pro commented on Jan 23, 2025

There was a potentially relevant bug fix: #1492.

janfb commented on Mar 20, 2025

See PR #1497 for potential improvements, e.g., this comment about the architecture: https://github.com/sbi-dev/sbi/pull/1497#issuecomment-2751932097
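
For anyone who wants to experiment along these lines before the PR lands, here is a rough sketch of how a larger vector-field network could be passed to FMPE. It assumes your sbi version exposes a flowmatching_nn builder in sbi.neural_nets and that it accepts hidden_features/num_blocks keywords; the exact names may differ between versions, so please check the signature of your install.

from sbi.inference import FMPE
from sbi.neural_nets import flowmatching_nn  # assumption: available in recent sbi versions

# Rough sketch, not a confirmed recipe: build a larger network for the vector field.
# The keyword names below (hidden_features, num_blocks) are assumptions and may
# differ between sbi versions; check the flowmatching_nn signature before using them.
estimator_builder = flowmatching_nn(
    model="resnet",
    hidden_features=256,
    num_blocks=5,
)

inference = FMPE(
    prior=prior_dist,                     # the task prior, as in the script above
    density_estimator=estimator_builder,  # builder instead of the "resnet" string
)
inference.append_simulations(theta_train, x_train).train()
posterior = inference.build_posterior()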

janfb commented on Mar 25, 2025

#1544 adds bug fixes and improvements to these methods and fixes many of the performance issues.

janfb commented on Aug 15, 2025