
feat: batched sampling for MCMC

Open manuelgloeckler opened this issue 1 year ago • 3 comments

What does this implement/fix? Explain your changes

This pull request aims to implement the sample_batched method for MCMC.

Current problem

  • BasePotential can either "allow_iid" or not; if it does, the batch dimension of x is interpreted as IID samples rather than as distinct observations.
    • [x] Replace allow_iid with a mutable attribute (or optional input argument) interpret_as_iid.
    • [ ] Remove warning for batched x and default to batched evaluation
  • Refactor all MCMC initialization methods to work with batch dim.
    • [ ] resample should break
    • [ ] SIR should break
    • [x] proposal should work
  • Add tests to check if correct samples are in each dimension (currently, only shapes are checked)
    • [x] The problem is currently not caught by the tests...

The current implementation lets you sample the correct shape, BUT outputs the wrong solution: the potential function broadcasts, repeats, and finally sums over the first (batch) dimension, which is incorrect when the batch contains distinct observations.
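A minimal sketch in plain PyTorch (toy values, hypothetical shapes, not the PR's actual code) of why summing over the batch dimension gives a different answer than a matched, batched evaluation:

```python
import torch

# log_prob(x_j | theta_i) for two distinct observations and two thetas:
# shape (batch_theta, batch_x).
log_probs = torch.tensor([[0.0, -1.0],
                          [-2.0, -3.0]])

# IID interpretation: sum over the x batch -> one potential per theta.
# Correct only if both x's are iid trials of the same condition.
iid_potential = log_probs.sum(dim=1)       # tensor([-1., -5.])

# Batched interpretation: each theta is matched to "its" x, so we want
# the diagonal, not the sum -- a different answer entirely.
batched_potential = log_probs.diagonal()   # tensor([0., -3.])
```

The shapes of both results are valid, which is why shape-only tests do not catch the bug.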

manuelgloeckler avatar Jun 18 '24 07:06 manuelgloeckler

Codecov Report

Attention: Patch coverage is 82.72727% with 19 lines in your changes missing coverage. Please review.

Project coverage is 75.62%. Comparing base (2398a7a) to head (813ee75). Report is 8 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1176      +/-   ##
==========================================
- Coverage   84.55%   75.62%   -8.94%     
==========================================
  Files          96       96              
  Lines        7603     7701      +98     
==========================================
- Hits         6429     5824     -605     
- Misses       1174     1877     +703     
| Flag | Coverage Δ |
|---|---|
| unittests | 75.62% <82.72%> (-8.94%) :arrow_down: |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Files | Coverage Δ |
|---|---|
| sbi/inference/posteriors/base_posterior.py | 86.04% <100.00%> (ø) |
| ...inference/potentials/likelihood_based_potential.py | 100.00% <100.00%> (ø) |
| sbi/inference/potentials/ratio_based_potential.py | 100.00% <100.00%> (ø) |
| sbi/utils/sbiutils.py | 78.35% <100.00%> (-8.21%) :arrow_down: |
| sbi/utils/user_input_checks.py | 76.31% <100.00%> (-7.20%) :arrow_down: |
| sbi/inference/abc/mcabc.py | 15.87% <0.00%> (-68.26%) :arrow_down: |
| sbi/inference/abc/smcabc.py | 12.44% <0.00%> (-69.96%) :arrow_down: |
| sbi/inference/potentials/base_potential.py | 92.85% <85.71%> (+0.35%) :arrow_up: |
| .../inference/potentials/posterior_based_potential.py | 95.23% <93.33%> (-1.74%) :arrow_down: |
| sbi/inference/posteriors/ensemble_posterior.py | 50.00% <0.00%> (-37.97%) :arrow_down: |

... and 2 more

... and 20 files with indirect coverage changes

codecov[bot] avatar Jun 18 '24 07:06 codecov[bot]

I've made some progress now towards this PR, and would like some feedback before I continue.

BasePotential can either "allow_iid" or not.

Given batch_dim_theta != batch_dim_x, we need to decide how to evaluate potential(x, theta). We could return (batch_dim_x, batch_dim_theta) potentials (i.e., every combination), but I am worried this can add a lot of computational overhead, especially when sampling. Instead, the current implementation assumes that batch_dim_theta is a multiple of batch_dim_x (i.e., for sampling, we have n chains in theta for each x). In this case we expand the batch dim of x to batch_dim_theta and match which x goes to which theta. If we are happy with this approach, I'll go ahead and apply it also to the MCMC init_strategy, etc., and make sure it is consistent with other calls.
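The expand-and-match step can be sketched in plain PyTorch; the names and the contiguous-chains convention here are hypothetical, not the PR's exact code:

```python
import torch

# Assumption: batch_theta is a multiple of batch_x (n chains per x).
batch_x, chains_per_x, x_dim, theta_dim = 2, 3, 4, 5
batch_theta = batch_x * chains_per_x

x = torch.randn(batch_x, x_dim)
theta = torch.randn(batch_theta, theta_dim)

# One possible matching convention: keep the chains for the same x
# contiguous, so x[0] pairs with theta[0:3] and x[1] with theta[3:6].
x_expanded = x.repeat_interleave(chains_per_x, dim=0)
assert x_expanded.shape == (batch_theta, x_dim)

# potential(x_expanded, theta) can now be evaluated element-wise over
# the shared batch dimension: one value per (theta, matched x) pair,
# instead of the full (batch_x, batch_theta) combination grid.
```

This keeps the cost linear in batch_theta rather than quadratic in the two batch sizes.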

Remove warning for batched x and default to batched evaluation

Not sure if we want batched evaluation as the default. I think it's easier to do batched evaluation when sample_batched or log_prob_batched is called, and otherwise assume iid (and warn if batch dim > 1, as before).

gmoss13 avatar Jun 27 '24 16:06 gmoss13

Great, it looks good. I like that the choice of iid or not can now be made in the set_x method, which makes a lot of sense.

I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? With PyTorch broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_theta_dim, batch_x_dim), so by squeezing the xs and thetas into 2d one would always get a theta dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor).

For (1,batch_x_dim,x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between cases can no longer be fully preserved.
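The full-combination case can be illustrated with a toy reduction (illustrative shapes only, not an actual potential):

```python
import torch

# Shapes with singleton dims, as in the broadcasting discussion above.
batch_x, batch_theta, x_dim, theta_dim = 3, 4, 2, 5
x = torch.randn(batch_x, x_dim).unsqueeze(0)              # (1, batch_x, x_dim)
theta = torch.randn(batch_theta, theta_dim).unsqueeze(1)  # (batch_theta, 1, theta_dim)

# Any pairwise reduction broadcasts the singleton dims against each
# other, yielding every (theta, x) combination.
pairwise = x.sum(-1, keepdim=True) + theta.sum(-1, keepdim=True)
assert pairwise.shape == (batch_theta, batch_x, 1)

# Once xs and thetas are squeezed to 2-D, these singleton dims are gone
# and the all-combinations pairing can no longer be expressed by shape.
```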

manuelgloeckler avatar Jun 28 '24 12:06 manuelgloeckler

Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

Thanks for the review! I implemented your suggestions.

An additional point - For posterior_based_potential, indeed we should not allow for iid_x, as this is handled by PermutationInvariantNetwork. Instead, we now always treat x batches as not iid. If the user tries to set potential.set_x(x,x_is_iid=True) with a PosteriorBasedPotential, we raise an error stating this. I added a few test cases in embedding_net_test.py::test_embedding_api_with_multiple_trials to test whether batches of x are interpreted correctly when we use a PermutationInvariantNetwork.

gmoss13 avatar Jul 19 '24 15:07 gmoss13

closes #990 closes #944

janfb avatar Jul 30 '24 09:07 janfb