
New organization of `SNPE` methods

Open • michaeldeistler opened this issue 1 year ago • 2 comments

Proposed structure:

```
inference
└── trainers
    └── npe
        ├── npe.py
        ├── snpe_a_correction.py
        └── snpe_c_loss.py
```

Then, the API for NPE (amortized) is:

```python
from sbi.inference import NPE, DirectPosterior

# Assumes `prior`, `theta` and `x` are already defined.
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
```
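To make "amortized" concrete, here is a minimal stdlib-only Python sketch that does not use sbi. It assumes a toy model (prior θ ~ N(0, 1), simulator x = θ + N(0, 1) noise) and restricts the conditional density to a linear Gaussian q(θ | x) = N(a·x + b, s2). For that family, maximizing the NPE objective (the mean of log q(θ | x) over simulated pairs) reduces to ordinary least squares. The names `simulate`, `fit_linear_gaussian_npe` and `log_q` are hypothetical. The point is that the fit happens once and is then evaluated at any observation without retraining.

```python
import math
import random


def simulate(n, rng):
    """Toy simulator (assumption): theta ~ N(0, 1), x = theta + N(0, 1) noise."""
    theta = [rng.gauss(0.0, 1.0) for _ in range(n)]
    x = [t + rng.gauss(0.0, 1.0) for t in theta]
    return theta, x


def fit_linear_gaussian_npe(theta, x):
    """Maximum-likelihood fit of q(theta | x) = N(a * x + b, s2).

    For this density family, maximizing mean log q(theta_i | x_i) is
    ordinary least squares for (a, b), with s2 the mean squared residual.
    """
    n = len(x)
    mx = sum(x) / n
    mt = sum(theta) / n
    cov = sum((xi - mx) * (ti - mt) for xi, ti in zip(x, theta)) / n
    var = sum((xi - mx) ** 2 for xi in x) / n
    a = cov / var
    b = mt - a * mx
    s2 = sum((ti - (a * xi + b)) ** 2 for xi, ti in zip(x, theta)) / n
    return a, b, s2


def log_q(theta, x, params):
    """Log-density of the fitted conditional Gaussian q(theta | x)."""
    a, b, s2 = params
    mean = a * x + b
    return -0.5 * (math.log(2 * math.pi * s2) + (theta - mean) ** 2 / s2)


rng = random.Random(0)
theta, x = simulate(20_000, rng)
params = fit_linear_gaussian_npe(theta, x)

# Amortization: one fit, then posterior summaries for several observations.
for x_o in (-1.0, 0.0, 2.0):
    a, b, s2 = params
    print(f"x_o={x_o:+.1f}: mean={a * x_o + b:+.3f}, var={s2:.3f}")
```

For this toy model the exact posterior is N(x_o / 2, 1/2), so the fitted `a`, `b` and `s2` should land close to 0.5, 0 and 0.5. Real NPE swaps the linear Gaussian for a flexible conditional density estimator (for example a normalizing flow or an MDN) trained by stochastic gradient methods on the same objective.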

For SNPE_A, it is:

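```python
# `snpe_a_correction` is the new function proposed in this issue.
from sbi.inference import NPE, DirectPosterior, snpe_a_correction

# Assumes `prior`, `simulator` and the observation `x_o` are already defined.
proposal = prior
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)

    # SNPE-A fits a single Gaussian in early rounds and an MDN in the final round.
    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    proposal = corrected_posterior
```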
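The correction step rests on the SNPE-A identity p(θ | x_o) ∝ q(θ | x_o) · p(θ) / p̃(θ), where q is the network trained on samples from the proposal p̃. When all three densities are Gaussian, the result is Gaussian again: precisions and precision-weighted means simply add and subtract. The sketch below shows that algebra for one-dimensional Gaussians only, assuming a Gaussian prior. `Gaussian` and `gaussian_snpe_a_correction` are hypothetical names; this is not the implementation of the proposed `snpe_a_correction`.

```python
from typing import NamedTuple


class Gaussian(NamedTuple):
    """A 1D Gaussian given by its mean and variance."""
    mean: float
    var: float


def gaussian_snpe_a_correction(q: Gaussian, proposal: Gaussian, prior: Gaussian) -> Gaussian:
    """Return the Gaussian proportional to q(theta | x_o) * prior(theta) / proposal(theta).

    Multiplying Gaussian densities adds precisions and precision-weighted
    means; dividing by a Gaussian density subtracts them.
    """
    precision = 1.0 / q.var + 1.0 / prior.var - 1.0 / proposal.var
    if precision <= 0.0:
        # The ratio is not a normalizable density, so the analytic correction fails.
        raise ValueError("corrected posterior is improper (non-positive precision)")
    weighted_mean = q.mean / q.var + prior.mean / prior.var - proposal.mean / proposal.var
    return Gaussian(mean=weighted_mean / precision, var=1.0 / precision)


corrected = gaussian_snpe_a_correction(
    q=Gaussian(mean=1.2, var=0.25),        # estimator output at x_o (made-up numbers)
    proposal=Gaussian(mean=1.0, var=0.5),  # proposal used in this round
    prior=Gaussian(mean=0.0, var=1.0),
)
print(corrected)
```

If the proposal equals the prior, those two terms cancel and `q` is returned unchanged, which is why the first round needs no correction. A non-positive precision means the ratio cannot be normalized, and the analytic correction breaks down.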

For SNPE_C (atomic), it is:

```python
# `snpe_c_atomic_loss` is the new loss function proposed in this issue.
from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss

# Assumes `prior`, `simulator` and the observation `x_o` are already defined.
# First round is standard NPE.
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`

# Later rounds use the APT (atomic) loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
```
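For reference, here is a stdlib-only sketch of what an atomic loss such as the proposed `snpe_c_atomic_loss` computes in the atomic APT formulation: for each x_i, the true θ_i is contrasted with other parameters from the batch, each atom is scored by log q(θ | x_i) - log p(θ), and the loss is the negative log-softmax weight of θ_i. It handles 1D parameters only. The function name, its signature and the choice of drawing contrasting atoms from the rest of the batch are assumptions for illustration, not sbi's implementation.

```python
import math
import random
from typing import Callable, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def atomic_apt_loss(
    theta: Sequence[float],
    x: Sequence[float],
    log_q: Callable[[float, float], float],
    log_prior: Callable[[float], float],
    num_atoms: int,
    rng: random.Random,
) -> float:
    """Mean atomic (APT / SNPE-C) loss over a batch of (theta, x) pairs.

    For each x_i, the true theta_i is contrasted with num_atoms - 1 other
    thetas from the batch. Each atom is scored by
    log q(theta | x_i) - log p(theta), and the loss is the negative
    log-softmax probability assigned to the true atom.
    """
    n = len(theta)
    if not 2 <= num_atoms <= n:
        raise ValueError("num_atoms must be between 2 and the batch size")
    total = 0.0
    for i in range(n):
        # Contrasting atoms: distinct parameters from other batch elements.
        others = rng.sample([j for j in range(n) if j != i], num_atoms - 1)
        atoms = [theta[i]] + [theta[j] for j in others]
        logits = [log_q(t, x[i]) - log_prior(t) for t in atoms]
        total += -(logits[0] - logsumexp(logits))
    return total / n
```

If `log_q` just returns the prior log-density, every logit is zero and the loss equals log(num_atoms); a useful estimator pushes it lower. The loss never evaluates the proposal density, which is why the atomic variant works with any density estimator.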

For SNPE_C (non-atomic), the only difference is that one would also pass `proposal=proposal` to `append_simulations()`, and the density estimator has to be an MDN.
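One reason the non-atomic variant needs MDNs: its loss evaluates the proposal posterior q̃(θ | x) ∝ q(θ | x) · p̃(θ) / p(θ), and when q is a mixture of Gaussians while the proposal p̃ and prior p are Gaussian, q̃ is again a mixture of Gaussians in closed form. Each component keeps its Gaussian shape, and its weight is rescaled by the integral of N_k · p̃ / p. The stdlib-only sketch below covers the 1D case with a Gaussian prior; `Component` and `proposal_posterior_mog` are hypothetical names, not sbi code. With a uniform prior, the prior term would instead be constant on its support.

```python
import math
from typing import List, NamedTuple


class Component(NamedTuple):
    """One 1D mixture component: weight, mean and variance."""
    weight: float
    mean: float
    var: float


def proposal_posterior_mog(
    mog: List[Component],
    proposal_mean: float,
    proposal_var: float,
    prior_mean: float,
    prior_var: float,
) -> List[Component]:
    """Closed form of q(theta | x) * proposal(theta) / prior(theta), renormalized.

    q is a 1D mixture of Gaussians; proposal and prior are Gaussians.
    Each component stays Gaussian; its weight is rescaled by the integral
    Z_k of N_k * proposal / prior.
    """
    new_params = []
    log_weights = []
    for c in mog:
        precision = 1.0 / c.var + 1.0 / proposal_var - 1.0 / prior_var
        if precision <= 0.0:
            raise ValueError("component becomes improper (non-positive precision)")
        eta = c.mean / c.var + proposal_mean / proposal_var - prior_mean / prior_var
        # Log of the constant multiplying exp(-precision * theta**2 / 2 + eta * theta).
        log_c = (
            -c.mean**2 / (2 * c.var)
            - proposal_mean**2 / (2 * proposal_var)
            + prior_mean**2 / (2 * prior_var)
            - 0.5 * math.log(2 * math.pi * c.var)
            - 0.5 * math.log(2 * math.pi * proposal_var)
            + 0.5 * math.log(2 * math.pi * prior_var)
        )
        # Gaussian integral: exp(-precision * theta**2 / 2 + eta * theta) over the real line.
        log_z = log_c + 0.5 * math.log(2 * math.pi / precision) + eta**2 / (2 * precision)
        log_weights.append(math.log(c.weight) + log_z)
        new_params.append((eta / precision, 1.0 / precision))
    # Normalize the rescaled weights in log space for stability.
    m = max(log_weights)
    unnorm = [math.exp(lw - m) for lw in log_weights]
    total = sum(unnorm)
    return [Component(u / total, mean, var) for u, (mean, var) in zip(unnorm, new_params)]
```

A component whose precision becomes non-positive has no normalizable proposal posterior, so the sketch raises an error rather than returning an improper density.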

michaeldeistler, Aug 28 '24 08:08

Some additional context: the current inference structure is:

```
inference
└── trainers
    └── npe
        ├── npe_base.py
        ├── npe_a.py
        ├── npe_b.py  (not implemented)
        └── npe_c.py
```

The only difference between `npe_a` and `npe_c` is how the proposal is handled over several rounds (i.e. via an atomic loss as in `npe_c`, or by correcting the proposal as in `npe_a`). It would be good to separate this correction from the classes so that it can be used when writing custom training loops like those above. The overall inference structure should remain backwards compatible, i.e. users should still be able to import `NPE_C` and `NPE_A` (and the `SNPE_A` and `SNPE_C` aliases), which should function as before.

gmoss13, Feb 26 '25 10:02

I can make a separate issue, but I just noticed that the tutorials need to be updated for the new `DirectPosterior` API.

jnsbck, Apr 15 '25 14:04