sbi
Add tutorial for training interface
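```python
import torch
from torch.optim import Adam

from sbi.analysis import pairplot
from sbi.inference.posteriors import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets.density_estimators.flow import NFlowsFlow
from sbi.neural_nets.flow import build_nsf
from sbi.utils import BoxUniform

# Hypothetical toy setup (not part of the original snippet): a 2-D uniform
# prior and a simulator that adds Gaussian noise to theta.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)
x_o = torch.zeros(1, 2)

# Build the neural density estimator for the likelihood p(x | theta).
net = build_nsf(x, theta)
de = NFlowsFlow(net, condition_shape=(theta.shape[1],))

# Train the density estimator by maximum likelihood (full batch, 100 steps).
opt = Adam(list(de.parameters()), lr=5e-4)
for _ in range(100):
    opt.zero_grad()
    log_probs = de.log_prob(x, condition=theta)
    loss = -torch.mean(log_probs)
    loss.backward()
    opt.step()

# Build the likelihood-based potential and sample the posterior with MCMC.
potential, tf = likelihood_estimator_based_potential(de.net, prior, x_o)
posterior = MCMCPosterior(
    potential,
    proposal=prior,
    theta_transform=tf,
    num_chains=100,
    method="slice_np_vectorized",
)
samples = posterior.sample((1000,), x=x_o)
_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))
```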
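The training loop above takes one full-batch gradient step per iteration for a fixed 100 iterations, with no validation check. A common alternative, loosely similar to the early stopping used by sbi's built-in `.train()` methods, is minibatch training with a held-out validation split. The sketch below is illustrative only: `ToyConditionalGaussian` is a hypothetical stand-in for `de` that exposes the same `log_prob(x, condition=theta)` call with `x` of shape `(batch, x_dim)` and `theta` of shape `(batch, theta_dim)`, and `train` with all of its parameters is not sbi API. The shapes that sbi's `NFlowsFlow.log_prob` expects may differ depending on the version.

```python
import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)


class ToyConditionalGaussian(nn.Module):
    """Hypothetical stand-in for `de`: p(x | theta) = Normal(W theta + b, diag(sigma^2))."""

    def __init__(self, x_dim: int, theta_dim: int):
        super().__init__()
        self.mean = nn.Linear(theta_dim, x_dim)
        self.log_std = nn.Parameter(torch.zeros(x_dim))

    def log_prob(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # x: (batch, x_dim), condition: (batch, theta_dim) -> (batch,)
        dist = torch.distributions.Normal(self.mean(condition), self.log_std.exp())
        return dist.log_prob(x).sum(dim=-1)


def train(de, x, theta, lr=5e-4, batch_size=100, max_epochs=200,
          val_fraction=0.1, patience=20):
    """Minibatch maximum-likelihood training with early stopping on a held-out split."""
    perm = torch.randperm(len(x))
    n_val = int(val_fraction * len(x))
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    opt = Adam(de.parameters(), lr=lr)
    best_val, best_state, stale_epochs = float("inf"), None, 0
    for _ in range(max_epochs):
        # One pass over the training set in freshly shuffled minibatches.
        for idx in train_idx[torch.randperm(len(train_idx))].split(batch_size):
            opt.zero_grad()
            loss = -de.log_prob(x[idx], condition=theta[idx]).mean()
            loss.backward()
            opt.step()
        # Negative log-likelihood on the validation split decides when to stop.
        with torch.no_grad():
            val_loss = -de.log_prob(x[val_idx], condition=theta[val_idx]).mean().item()
        if val_loss < best_val:
            best_val, stale_epochs = val_loss, 0
            best_state = {k: v.clone() for k, v in de.state_dict().items()}
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    de.load_state_dict(best_state)  # restore the best validation checkpoint
    return best_val


# Toy data: x = theta + small Gaussian noise, both of shape (1000, 2).
theta = torch.rand(1000, 2) * 4 - 2
x = theta + 0.1 * torch.randn_like(theta)
de = ToyConditionalGaussian(x_dim=2, theta_dim=2)
with torch.no_grad():
    initial_nll = -de.log_prob(x, condition=theta).mean().item()
final_val_nll = train(de, x, theta, lr=1e-2)
print(f"NLL before: {initial_nll:.3f}, best validation NLL: {final_val_nll:.3f}")
```

Restoring the best validation checkpoint rather than the last one keeps the estimator from drifting toward overfitting the training split once the validation loss stops improving.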
TODO:
- [ ] Have an abstraction that avoids needing two lines to build the density estimator (not sure about this one...)
- [ ] Make the `DirectPosterior` and all potentials use the `DensityEstimator` abstraction
- [ ] Make it easier to use NRE by having an abstraction for the loss
Codecov Report
All modified and coverable lines are covered by tests ✅
Project coverage is 76.87%. Comparing base (b3254ed) to head (827df2a).
Report is 11 commits behind head on main.
```diff
@@            Coverage Diff             @@
##             main     #983      +/-   ##
==========================================
- Coverage   85.44%   76.87%   -8.57%
==========================================
  Files         101      101
  Lines        7941     7945       +4
==========================================
- Hits         6785     6108     -677
- Misses       1156     1837     +681
```
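The figures in the coverage diff are internally consistent: hits plus misses equal lines on both sides (6785 + 1156 = 7941 and 6108 + 1837 = 7945), and coverage is hits divided by lines. The check below recomputes the percentages with integer arithmetic, assuming Codecov rounds down to two decimals; 6108 / 7945 is about 76.879%, which would round to 76.88% but is reported as 76.87%. The helper name `coverage_hundredths` is illustrative.

```python
# Numbers copied from the coverage diff (base = main, head = this PR).
base = {"hits": 6785, "misses": 1156, "lines": 7941}
head = {"hits": 6108, "misses": 1837, "lines": 7945}


def coverage_hundredths(hits: int, lines: int) -> int:
    """Coverage in hundredths of a percent, rounded down (assumed Codecov behaviour)."""
    return hits * 10_000 // lines


# Every tracked line is either a hit or a miss.
for report in (base, head):
    assert report["hits"] + report["misses"] == report["lines"]

base_cov = coverage_hundredths(base["hits"], base["lines"])
head_cov = coverage_hundredths(head["hits"], head["lines"])
delta = head_cov - base_cov
print(f"{base_cov / 100:.2f}% -> {head_cov / 100:.2f}% ({delta / 100:+.2f}%)")
```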
| Flag | Coverage Δ | |
|---|---|---|
| unittests | `76.87% <ø> (-8.57%)` | ⬇️ |
Flags with carried forward coverage won't be shown.
May I ask if you could provide a minimal working example for the code given above? :) In the current version I really struggle to get the dimensions right during the `log_probs = de.log_prob(x, condition=theta)` call.