
Add tutorial for training interface

Open · michaeldeistler opened this issue 1 year ago · 1 comment

```python
import torch
from torch.optim import Adam

from sbi.analysis import pairplot
from sbi.neural_nets.flow import build_nsf
from sbi.neural_nets.density_estimators.flow import NFlowsFlow
from sbi.inference.posteriors import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential

# Assumes `prior`, simulated parameters `theta`, simulated data `x`,
# and an observation `x_o` have already been defined.

# Build a neural spline flow and wrap it as a density estimator of x given theta.
net = build_nsf(x, theta)
de = NFlowsFlow(net, condition_shape=(theta.shape[1],))

# Train the density estimator by minimizing the negative log-likelihood.
opt = Adam(list(de.parameters()), lr=5e-4)
for _ in range(100):
    opt.zero_grad()
    log_probs = de.log_prob(x, condition=theta)
    loss = -torch.mean(log_probs)
    loss.backward()
    opt.step()

# Build a likelihood-based potential and sample the posterior with MCMC.
potential, tf = likelihood_estimator_based_potential(de.net, prior, x_o)
posterior = MCMCPosterior(
    potential,
    proposal=prior,
    theta_transform=tf,
    num_chains=100,
    method="slice_np_vectorized",
)
samples = posterior.sample((1000,), x=x_o)
_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))
```
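The snippet above relies on `x`, `theta`, `prior`, and `x_o` being defined elsewhere. To make the training-loop step reproducible on its own, here is a minimal sketch in plain PyTorch that swaps the neural spline flow for a hypothetical `ConditionalGaussian` estimator (not part of sbi) and uses a made-up toy simulator. It only illustrates the maximum-likelihood loop and the tensor shapes in this stand-in: `x` and `theta` are `(num_sims, dim)` with matching first dimensions, and `log_prob` returns one value per simulation. sbi's own estimators may expect different shapes.

```python
import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)


class ConditionalGaussian(nn.Module):
    """Hypothetical stand-in (not sbi's API) for a conditional density q(x | theta).

    An MLP maps theta to the mean and log-std of a diagonal Gaussian over x.
    """

    def __init__(self, theta_dim: int, x_dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(theta_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * x_dim)
        )

    def log_prob(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # x: (batch, x_dim), condition: (batch, theta_dim) -> returns (batch,)
        mean, log_std = self.net(condition).chunk(2, dim=-1)
        dist = torch.distributions.Normal(mean, log_std.exp())
        return dist.log_prob(x).sum(dim=-1)  # sum over independent x dimensions


# Toy simulator (an assumption for illustration): x = theta + small Gaussian noise.
num_sims, theta_dim = 2000, 2
theta = torch.rand(num_sims, theta_dim) * 6 - 3  # uniform on [-3, 3]^2
x = theta + 0.1 * torch.randn(num_sims, theta_dim)

de = ConditionalGaussian(theta_dim, x_dim=x.shape[1])
opt = Adam(de.parameters(), lr=5e-3)

losses = []
for _ in range(300):
    opt.zero_grad()
    log_probs = de.log_prob(x, condition=theta)  # shape: (num_sims,)
    loss = -log_probs.mean()  # negative log-likelihood
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"initial NLL {losses[0]:.3f}, final NLL {losses[-1]:.3f}")
```

Because the loss is just the mean negative log-probability, any estimator exposing a `log_prob(x, condition=...)` method with these shapes could be dropped into the same loop.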

TODO:

  • [ ] Add an abstraction so that building the density estimator takes one line instead of two (not sure about this one...)
  • [ ] Make the DirectPosterior and all potentials use the DensityEstimator abstraction
  • [ ] Make NRE easier to use by adding an abstraction for the loss
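The last TODO item asks for a loss abstraction so that NRE is easier to use. One possible shape for it, sketched in plain PyTorch with hypothetical names (`nle_loss`, `nre_loss`, `train`) that are not part of sbi: each method supplies a loss function with a common signature, and a single training loop consumes it. The NRE loss shown is the binary-classifier variant (matched vs. shuffled pairs); other NRE variants would plug in the same way.

```python
from typing import Callable

import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

# Hypothetical interface (not sbi's API): a loss maps (model, theta, x) to a scalar.
LossFn = Callable[[nn.Module, torch.Tensor, torch.Tensor], torch.Tensor]


def nle_loss(estimator: nn.Module, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Likelihood estimation: negative log-likelihood of x given theta.
    # Assumes the estimator exposes log_prob(x, condition=theta) -> (batch,).
    return -estimator.log_prob(x, condition=theta).mean()


def nre_loss(classifier: nn.Module, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Ratio estimation (binary-classifier variant): matched (theta, x) pairs are
    # labeled 1, pairs with theta shuffled across the batch are labeled 0.
    theta_shuffled = theta[torch.randperm(theta.shape[0])]
    logits_joint = classifier(torch.cat([theta, x], dim=-1)).squeeze(-1)
    logits_marginal = classifier(torch.cat([theta_shuffled, x], dim=-1)).squeeze(-1)
    logits = torch.cat([logits_joint, logits_marginal])
    labels = torch.cat([torch.ones_like(logits_joint), torch.zeros_like(logits_marginal)])
    return binary_cross_entropy_with_logits(logits, labels)


def train(model: nn.Module, loss_fn: LossFn, theta, x, steps: int = 300, lr: float = 3e-3):
    # One generic loop; switching methods only means passing a different loss_fn.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model, theta, x)
        loss.backward()
        opt.step()
        history.append(loss.item())
    return history


# Usage with a toy simulator and a small MLP classifier (both illustrative).
torch.manual_seed(0)
theta = torch.rand(1000, 2) * 6 - 3  # uniform on [-3, 3]^2
x = theta + 0.1 * torch.randn(1000, 2)
classifier = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))
history = train(classifier, nre_loss, theta, x)
print(f"NRE loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

The design choice here is that the training loop never needs to know which method it is training; an NLE run would call `train(estimator, nle_loss, theta, x)` with a model that has a `log_prob` method.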

michaeldeistler, Mar 11 '24 15:03

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.87%. Comparing base (b3254ed) to head (827df2a). Report is 11 commits behind head on main.

```diff
@@            Coverage Diff             @@
##             main     #983      +/-   ##
==========================================
- Coverage   85.44%   76.87%   -8.57%     
==========================================
  Files         101      101              
  Lines        7941     7945       +4     
==========================================
- Hits         6785     6108     -677     
- Misses       1156     1837     +681     
```
| Flag | Coverage Δ |
|---|---|
| unittests | 76.87% <ø> (-8.57%) ⬇️ |

Flags with carried forward coverage won't be shown.

25 files have indirect coverage changes.
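The headline percentages in the report are consistent with hits divided by tracked lines. A quick check in Python, assuming the values are rounded down to two decimals (assumed here to be Codecov's default rounding; the report itself does not say):

```python
def coverage_bp(hits: int, lines: int) -> int:
    """Coverage in hundredths of a percent, rounded down (integer arithmetic)."""
    return hits * 10_000 // lines


def fmt(bp: int) -> str:
    """Format hundredths of a percent as a signed percentage string."""
    sign = "-" if bp < 0 else ""
    bp = abs(bp)
    return f"{sign}{bp // 100}.{bp % 100:02d}%"


base = coverage_bp(6785, 7941)  # main: hits / lines
head = coverage_bp(6108, 7945)  # #983: hits / lines
print(fmt(base), fmt(head), fmt(head - base))  # 85.44% 76.87% -8.57%
```

Hits plus misses also add up to the line counts on both sides (6785 + 1156 = 7941 and 6108 + 1837 = 7945).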

codecov[bot], Mar 11 '24 15:03

Could you provide a minimal working example for the code above? :) With the current version, I struggle to get the tensor dimensions right in the `log_probs = de.log_prob(x, condition=theta)` call.

Kojobu, Jul 17 '24 13:07