feat: batched sampling for MCMC
What does this implement/fix? Explain your changes
This pull request aims to implement the `sample_batched` method for MCMC.
Current problem
`BasePotential` can either "allow_iid" or not. Hence, each batch dimension will be interpreted as IID samples.

- [x] Replace `allow_iid` with a mutable attribute (or optional input argument) `interpret_as_iid` (see the sketch after this list).
- [ ] Remove warning for batched x and default to batched evaluation.
- [x] Refactor all MCMC initialization methods to work with batch dim:
  - [ ] resample should break
  - [ ] SIR should break
  - [x] proposal should work
- [ ] Add tests to check if the correct samples are in each dimension (currently, only shapes are checked).
  - [x] The problem is currently not caught by tests...
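As a concrete sketch of the first item (purely illustrative; the class name and `set_x` signature are simplified assumptions, not the PR's actual code), the iid decision becomes a mutable, per-observation choice instead of a fixed class property:

```python
import torch

class BasePotentialSketch:
    """Illustrative stand-in: `interpret_as_iid` replaces a fixed `allow_iid`."""

    def __init__(self) -> None:
        self._x = None
        self.interpret_as_iid = False  # mutable attribute, set per observation

    def set_x(self, x: torch.Tensor, x_is_iid: bool = False) -> None:
        # The iid-vs-batched decision is made when x is set, not at class level.
        self._x = x
        self.interpret_as_iid = x_is_iid
```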
The current implementation will let you sample the correct shape, BUT will output the wrong solution. This is because the potential function will broadcast, repeat, and finally sum over the first dimension, which is incorrect.
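To make the failure mode concrete, here is a toy contrast between the two interpretations (hypothetical helper names, not sbi's internals; `log_prob_fn` stands in for the likelihood/posterior evaluation):

```python
import torch

def iid_potential(log_prob_fn, theta, x):
    """IID interpretation: x is (num_trials, x_dim), i.e. repeated trials of a
    single observation, so log-probs are summed over the trial dimension."""
    # theta: (theta_dim,), broadcast against every trial, then summed.
    return log_prob_fn(theta.expand(x.shape[0], -1), x).sum(dim=0)

def batched_potential(log_prob_fn, theta, x):
    """Batched interpretation: x is (batch, x_dim), theta is (batch, theta_dim);
    each theta_i is evaluated only against its own x_i -- no summing."""
    return log_prob_fn(theta, x)  # shape (batch,)
```

Summing, as in the iid case, collapses the batch dimension and mixes potentials across distinct observations, which is exactly the wrong-solution behavior described above.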
Codecov Report
Attention: Patch coverage is 82.72727% with 19 lines in your changes missing coverage. Please review.
Project coverage is 75.62%. Comparing base (`2398a7a`) to head (`813ee75`). Report is 8 commits behind head on main.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1176      +/-   ##
==========================================
- Coverage   84.55%   75.62%    -8.94%
==========================================
  Files          96       96
  Lines        7603     7701      +98
==========================================
- Hits         6429     5824     -605
- Misses       1174     1877     +703
```
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 75.62% <82.72%> (-8.94%) | :arrow_down: |

Flags with carried forward coverage won't be shown.
| Files | Coverage Δ | |
|---|---|---|
| sbi/inference/posteriors/base_posterior.py | 86.04% <100.00%> (ø) | |
| ...inference/potentials/likelihood_based_potential.py | 100.00% <100.00%> (ø) | |
| sbi/inference/potentials/ratio_based_potential.py | 100.00% <100.00%> (ø) | |
| sbi/utils/sbiutils.py | 78.35% <100.00%> (-8.21%) | :arrow_down: |
| sbi/utils/user_input_checks.py | 76.31% <100.00%> (-7.20%) | :arrow_down: |
| sbi/inference/abc/mcabc.py | 15.87% <0.00%> (-68.26%) | :arrow_down: |
| sbi/inference/abc/smcabc.py | 12.44% <0.00%> (-69.96%) | :arrow_down: |
| sbi/inference/potentials/base_potential.py | 92.85% <85.71%> (+0.35%) | :arrow_up: |
| .../inference/potentials/posterior_based_potential.py | 95.23% <93.33%> (-1.74%) | :arrow_down: |
| sbi/inference/posteriors/ensemble_posterior.py | 50.00% <0.00%> (-37.97%) | :arrow_down: |
| ... and 2 more | | |
I've made some progress on this PR and would like some feedback before I continue.
> `BasePotential` can either "allow_iid" or not.
Given `batch_dim_theta != batch_dim_x`, we need to decide how to evaluate `potential(x, theta)`. We could return `(batch_dim_x, batch_dim_theta)` potentials (i.e., every combination), but I am worried this could add a lot of computational overhead, especially when sampling. Instead, in the current implementation I suggest we assume that `batch_dim_theta` is a multiple of `batch_dim_x` (i.e., for sampling, we have n chains in theta for each x). In this case we expand the batch dim of x to `batch_dim_theta` and match which x goes to which theta, as sketched below. If we are happy with this approach, I'll go ahead and apply it also to the MCMC `init_strategy`, etc., and make sure it is consistent with other calls.
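A minimal sketch of this matching rule (a hypothetical helper, not the PR's actual code; it assumes theta is ordered as consecutive blocks of chains per x):

```python
import torch

def match_x_to_theta(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Expand x of shape (batch_x, x_dim) to theta's batch size (batch_theta, x_dim).

    Assumes batch_theta is a multiple of batch_x and theta is laid out as
    [chains for x_0, chains for x_1, ...], so each x_i is repeated once per chain.
    """
    batch_x, batch_theta = x.shape[0], theta.shape[0]
    if batch_theta % batch_x != 0:
        raise ValueError("batch_dim_theta must be a multiple of batch_dim_x.")
    num_chains = batch_theta // batch_x
    # (batch_x, x_dim) -> (batch_theta, x_dim): x_0 repeated num_chains times, etc.
    return x.repeat_interleave(num_chains, dim=0)
```

Each theta_i is then evaluated against its matched x_i, avoiding the full `(batch_dim_x, batch_dim_theta)` cross-product.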
> Remove warning for batched x and default to batched evaluation

I'm not sure we want batched evaluation as the default. I think it's easier to do batched evaluation when `sample_batched` or `log_prob_batched` is called, and otherwise assume iid (and warn if batch dim > 1, as before).
Great, it looks good. I like that the choice of iid or not can now be made via the `set_x` method, which makes a lot of sense.
I would also opt for your suggested option. The question arises because we squeeze the batch shape into a single dimension, right? With PyTorch broadcasting, one would expect something like `(1, batch_x_dim, x_dim)` and `(batch_theta_dim, batch_x_dim, theta_dim)` -> `(batch_theta_dim, batch_x_dim)`, so by squeezing the xs and thetas into 2d, one would always get a dimension that is a multiple of `batch_x_dim` (otherwise it cannot be represented by a fixed-size tensor).
For `(1, batch_x_dim, x_dim)` and `(batch_theta_dim, 1, theta_dim)`, PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between the cases can no longer be fully preserved.
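A toy example of that broadcasting behavior (illustrative only; a squared-distance "potential" with `x_dim == theta_dim` just so the tensors subtract):

```python
import torch

batch_x, batch_theta, dim = 3, 6, 2
x = torch.randn(1, batch_x, dim)          # (1, batch_x, x_dim)
theta = torch.randn(batch_theta, 1, dim)  # (batch_theta, 1, theta_dim)

# Broadcasting pairs every theta with every x: the batch dims expand to
# (batch_theta, batch_x), i.e. all combinations are computed.
pairwise = -((theta - x) ** 2).sum(dim=-1)
print(pairwise.shape)  # torch.Size([6, 3])
```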
Great effort, thanks a lot for tackling this 👏
I do have a couple of comments and questions. Happy to discuss in person if needed.
Thanks for the review! I implemented your suggestions.
An additional point: for `posterior_based_potential`, we should indeed not allow iid x, as this is handled by the `PermutationInvariantNetwork`. Instead, we now always treat x batches as not iid. If the user tries to call `potential.set_x(x, x_is_iid=True)` on a `PosteriorBasedPotential`, we raise an error stating this. I added a few test cases in `embedding_net_test.py::test_embedding_api_with_multiple_trials` to test whether batches of x are interpreted correctly when we use a `PermutationInvariantNetwork`.
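A minimal sketch of that guard (standalone and simplified; not the actual `PosteriorBasedPotential` implementation):

```python
import torch

class PosteriorBasedPotentialSketch:
    """Illustrative stand-in for the posterior-based potential's set_x guard."""

    def set_x(self, x: torch.Tensor, x_is_iid: bool = False) -> None:
        if x_is_iid:
            raise ValueError(
                "Interpreting x as iid is not supported for posterior-based "
                "potentials; iid trials should be handled by the embedding "
                "net (e.g., a PermutationInvariantNetwork)."
            )
        self._x = x  # batched x, always interpreted as independent observations
```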
closes #990 closes #944