New organization of `SNPE` methods
```
-- inference
----- trainers
--------- npe
------------- npe.py
------------- snpe_a_correction.py
------------- snpe_c_loss.py
```
Then, the API for NPE (amortized) is:
```python
from sbi.inference import NPE, DirectPosterior

trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
```
For SNPE_A, it is:
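```python
from sbi.inference import NPE, DirectPosterior, snpe_a_correction  # `snpe_a_correction` is the proposed new function

# Assumes `prior`, `simulator` and the observation `x_o` are already defined.
proposal = prior  # the first round simulates from the prior
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)
    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    proposal = corrected_posterior
```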
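The proposed `snpe_a_correction` has to undo the bias from training on proposal samples. As a minimal sketch, assuming the single-component, one-dimensional Gaussian case (prior, proposal and proposal posterior all Gaussian), the correction has a closed form: the network learns q̃(θ|x_o) ∝ p(θ|x_o) p̃(θ) / p(θ), so p(θ|x_o) ∝ q̃(θ|x_o) p(θ) / p̃(θ), and for Gaussians the precisions and precision-weighted means simply add and subtract. `Gaussian1D` and `gaussian_snpe_a_correction` are hypothetical names used only for illustration; they are not part of sbi.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Gaussian1D:
    """A univariate Gaussian N(mean, var)."""

    mean: float
    var: float

    def log_prob(self, theta: float) -> float:
        return -0.5 * (math.log(2 * math.pi * self.var) + (theta - self.mean) ** 2 / self.var)


def gaussian_snpe_a_correction(
    proposal_posterior: Gaussian1D, proposal: Gaussian1D, prior: Gaussian1D
) -> Gaussian1D:
    """Hypothetical 1D, single-Gaussian version of the proposed `snpe_a_correction`.

    Computes p(theta | x_o) ∝ q(theta | x_o) * prior(theta) / proposal(theta).
    """
    precision = 1 / proposal_posterior.var + 1 / prior.var - 1 / proposal.var
    if precision <= 0:
        # The product is not a normalizable Gaussian if the proposal is too narrow.
        raise ValueError("Corrected precision is not positive: the proposal is too narrow.")
    weighted_mean = (
        proposal_posterior.mean / proposal_posterior.var
        + prior.mean / prior.var
        - proposal.mean / proposal.var
    )
    return Gaussian1D(mean=weighted_mean / precision, var=1 / precision)


# Example: a network trained on proposal samples returns N(1.0, 0.8) at x_o.
prior = Gaussian1D(0.0, 4.0)
proposal = Gaussian1D(0.5, 2.0)
corrected = gaussian_snpe_a_correction(Gaussian1D(1.0, 0.8), proposal, prior)
print(corrected)  # Gaussian1D(mean=1.0, var=1.0)
```

Here N(1.0, 0.8) is exactly what a posterior N(1.0, 1.0) turns into under this prior and proposal, so the correction recovers it. When the proposal equals the prior, the function returns the network output unchanged, matching the first round of the loop above.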
For SNPE_C (atomic), it is:
```python
from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss  # `snpe_c_atomic_loss` is the proposed new loss

# Assumes `prior`, `simulator` and the observation `x_o` are already defined.
# First round is standard NPE.
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
# Later rounds use the APT (atomic) loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
```
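For reference, a plain-Python sketch of what an atomic (APT) loss like the proposed `snpe_c_atomic_loss` computes, assuming scalar θ and x and log-density callables in place of a network. For each pair (θ_i, x_i), the "atoms" are θ_i plus `num_atoms - 1` other θs from the same batch, and the loss is the negative log-probability of θ_i under q(θ|x_i) / p(θ) normalized over those atoms. `atomic_loss` and its arguments are hypothetical and do not reflect the signature the new function would have in sbi.

```python
import math
import random
from typing import Callable, Optional, Sequence


def atomic_loss(
    thetas: Sequence[float],
    xs: Sequence[float],
    posterior_log_prob: Callable[[float, float], float],
    prior_log_prob: Callable[[float], float],
    num_atoms: int = 10,
    rng: Optional[random.Random] = None,
) -> float:
    """Batch-averaged atomic (APT) loss; an illustrative sketch, not sbi's code."""
    rng = rng if rng is not None else random.Random(0)
    n = len(thetas)
    if not 2 <= num_atoms <= n:
        raise ValueError("num_atoms must lie between 2 and the batch size")
    total = 0.0
    for i in range(n):
        # Contrasting atoms: other thetas from the batch, drawn without replacement.
        others = rng.sample([j for j in range(n) if j != i], num_atoms - 1)
        atoms = [thetas[i]] + [thetas[j] for j in others]
        # Unnormalized log-weights: log q(theta | x_i) - log prior(theta).
        logits = [posterior_log_prob(t, xs[i]) - prior_log_prob(t) for t in atoms]
        # Numerically stable log-sum-exp over the atoms.
        top = max(logits)
        log_norm = top + math.log(sum(math.exp(v - top) for v in logits))
        total += log_norm - logits[0]  # -log softmax of the true atom
    return total / n


def log_normal(t: float, mean: float, var: float) -> float:
    return -0.5 * (math.log(2 * math.pi * var) + (t - mean) ** 2 / var)


def prior_lp(t: float) -> float:
    return log_normal(t, 0.0, 25.0)


def untrained_q(t: float, x: float) -> float:
    return prior_lp(t)  # an estimator that has learned nothing


def sharp_q(t: float, x: float) -> float:
    return log_normal(t, x, 0.01)  # an estimator concentrated on the true theta


thetas = [0.5 * i for i in range(20)]
xs = thetas  # toy "simulator": x equals theta
untrained = atomic_loss(thetas, xs, untrained_q, prior_lp)
sharp = atomic_loss(thetas, xs, sharp_q, prior_lp)
```

With q equal to the prior, every atom gets the same weight and the loss is log(num_atoms); an estimator that puts its mass on the true θ drives the loss towards zero. Since only ratios q/p over the atoms enter, the proposal never needs to be evaluated, which is why the atomic variant works with any density estimator.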
For SNPE_C (non-atomic), the only difference is that one would also pass `proposal=proposal` to `append_simulations()`, and the density estimator (and therefore the proposal) has to be an MDN.
Some additional context. Currently, the inference structure is:
```
-- inference
----- trainers
--------- npe
------------- npe_base.py
------------- npe_a.py
------------- npe_b.py (which is not implemented)
------------- npe_c.py
```
The only difference between `npe_a` and `npe_c` is how the proposal is handled over several rounds: `npe_c` uses an atomic loss, while `npe_a` corrects for the proposal after training. It would be good to separate this correction from the classes so that it can also be used in custom training loops like the ones above. The overall inference structure should remain backwards compatible, i.e. users should still be able to import `NPE_C` and `NPE_A` (and the `SNPE_A` and `SNPE_C` aliases), which should function as before.
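One possible way to keep the old classes working is to turn them into thin subclasses of the generic trainer that only plug in the standalone loss (or correction) by default. The sketch below shows just that delegation pattern with toy stand-ins; `ToyNPE`, `ToyNPE_C`, `ToySNPE_C` and the placeholder `snpe_c_atomic_loss` are hypothetical and do not mirror sbi's actual classes or signatures.

```python
from typing import Any, Callable, List, Optional, Tuple


def snpe_c_atomic_loss(*args: Any, **kwargs: Any) -> float:
    """Placeholder for the proposed standalone loss (hypothetical)."""
    raise NotImplementedError


class ToyNPE:
    """Toy generic trainer: stores rounds of data and 'trains' with a given loss."""

    def __init__(self) -> None:
        self.rounds: List[Tuple[Any, Any, Any]] = []

    def append_simulations(self, theta: Any, x: Any, proposal: Any = None) -> "ToyNPE":
        self.rounds.append((theta, x, proposal))
        return self

    def train(self, loss: Optional[Callable[..., float]] = None) -> str:
        # A real trainer would fit a network; this only reports which loss would be used.
        name = loss.__name__ if loss is not None else "maximum_likelihood"
        return f"net trained on {len(self.rounds)} round(s) with {name}"


class ToyNPE_C(ToyNPE):
    """Old-style class kept for backwards compatibility.

    It has no training logic of its own: after the first round it defaults to
    the same standalone loss that a custom loop would pass explicitly.
    """

    def train(self, loss: Optional[Callable[..., float]] = None) -> str:
        if loss is None and len(self.rounds) > 1:
            loss = snpe_c_atomic_loss
        return super().train(loss=loss)


ToySNPE_C = ToyNPE_C  # the legacy alias is just another name for the same class

trainer = ToySNPE_C()
print(trainer.append_simulations("theta0", "x0").train())  # net trained on 1 round(s) with maximum_likelihood
print(trainer.append_simulations("theta1", "x1").train())  # net trained on 2 round(s) with snpe_c_atomic_loss
```

Because the subclass only chooses a default, the multi-round logic lives in one place, and custom loops and the legacy classes cannot drift apart.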
I can open a separate issue for this, but I just noticed that the tutorials need to be updated for the new `DirectPosterior` API.