RatioEstimator abstraction

Open bkmi opened this issue 1 year ago • 7 comments

What does this implement/fix? Explain your changes

This introduces a RatioEstimator abstraction that wraps neural networks which process x and theta to estimate ratios. The base class suggests how to create extensions that use alternative data types (dict, etc.) while staying flexible enough to handle how embedding works and how embedded data is combined. Natural extensions include ratio estimators with specific network architectures such as transformers or convnets.
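To make the embed/combine/classify structure concrete, here is a minimal sketch. It is not the class added in this PR: `RatioEstimatorSketch`, `combine`, and the toy linear "network" are hypothetical names chosen for illustration, and NumPy arrays stand in for torch tensors to keep the example dependency-light.

```python
import numpy as np


class RatioEstimatorSketch:
    """Hypothetical stand-in for a ratio estimator (not the sbi class)."""

    def __init__(self, net, embed_theta=None, embed_x=None):
        # net maps combined features of shape (batch, features) to (batch,).
        self.net = net
        # Embeddings default to the identity when none is given.
        self.embed_theta = embed_theta if embed_theta is not None else (lambda a: a)
        self.embed_x = embed_x if embed_x is not None else (lambda a: a)

    @staticmethod
    def combine(z_theta, z_x):
        # Assumed combination rule: flatten each embedding per example,
        # then concatenate along the feature axis.
        flat_theta = z_theta.reshape(len(z_theta), -1)
        flat_x = z_x.reshape(len(z_x), -1)
        return np.concatenate([flat_theta, flat_x], axis=1)

    def unnormalized_log_ratio(self, theta, x):
        z = self.combine(self.embed_theta(theta), self.embed_x(x))
        return self.net(z)


rng = np.random.default_rng(0)
weights = rng.normal(size=(3 + 4,))  # toy linear "classifier" weights

estimator = RatioEstimatorSketch(
    net=lambda z: z @ weights,
    embed_x=lambda x: x.mean(axis=1),  # toy embedding: (batch, 5, 4) -> (batch, 4)
)

theta = rng.normal(size=(8, 3))
x = rng.normal(size=(8, 5, 4))
log_ratios = estimator.unnormalized_log_ratio(theta, x)
print(log_ratios.shape)
```

Swapping `embed_x` or `combine` is where an extension for other data types or architectures would plug in under this sketch.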

This implements #992, which is the ratio-specific issue for #1046. It also implements the ratio part of #957 and makes the documentation conform to the shape goals of #1041.

Does this close any currently open issues?

Closes: #992

Also #1036, since it eliminates a confusingly named and now-irrelevant class.

The others are more general and require more work on subjects related to ratios before closing.

Any relevant code examples, logs, error output, etc?

Nope

Any other comments?

There is a test that doesn't pass, namely #1090, but I think it is not because of my changes.

Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

  • [x] I have read and understood the contribution guidelines
  • [x] I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have reported how long the new tests run and potentially marked them with pytest.mark.slow.
  • [x] New and existing unit tests pass locally with my changes
  • [x] I performed linting and formatting as described in the contribution guidelines
  • [x] I rebased on main (or there are no conflicts with main)

bkmi avatar Mar 22 '24 10:03 bkmi

I added a test. It's extremely simple, but it checks that the estimated ratios are the correct size (like log_prob).

bkmi avatar Mar 22 '24 11:03 bkmi

I will follow up after Friday.

bkmi avatar Mar 25 '24 09:03 bkmi

@tomMoral I think this is done now.

bkmi avatar Apr 09 '24 13:04 bkmi

Codecov Report

Attention: Patch coverage is 94.91525% with 3 lines in your changes missing coverage. Please review.

Project coverage is 75.61%. Comparing base (6fd2a6b) to head (8884241).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1097      +/-   ##
==========================================
- Coverage   84.54%   75.61%   -8.93%     
==========================================
  Files          95       96       +1     
  Lines        7576     7603      +27     
==========================================
- Hits         6405     5749     -656     
- Misses       1171     1854     +683     
Flag Coverage Δ
unittests 75.61% <94.91%> (-8.93%) :arrow_down:

Files Coverage Δ
sbi/inference/potentials/ratio_based_potential.py 100.00% <100.00%> (ø)
sbi/inference/snre/snre_base.py 94.44% <100.00%> (ø)
sbi/neural_nets/__init__.py 100.00% <ø> (ø)
sbi/neural_nets/classifier.py 100.00% <100.00%> (ø)
sbi/neural_nets/ratio_estimators.py 92.68% <92.68%> (ø)

... and 24 files with indirect coverage changes

codecov[bot] avatar Apr 09 '24 13:04 codecov[bot]

~~There is a wip push, but supporting sample dim will require:~~ ~~- [ ] making warnings that it is not allowed to be used in loss function computations~~

bkmi avatar Apr 22 '24 08:04 bkmi

goals:

  • [x] no abstraction, only one RatioEstimator
  • ~~[ ] move RatioEstimator out of neural_nets~~
  • [x] primarily RatioEstimator keeps track of x and theta shapes
  • [x] simple check shapes, i.e. (1, 1, *event_shape) for log_prob equivalent.
  • [x] RatioEstimator can have embeddings
  • [x] make sure that the loss is computed in the learning algorithm
  • [x] update tests

relevant issues: #1066 #1149

bkmi avatar Apr 25 '24 12:04 bkmi

As commented in #1103, please make sure that all relevant changes from #1103 (those not related to renaming) are moved to this PR (I think they are all in here anyway).

Otherwise I think that this PR is almost done, no? The RatioEstimator is now independent from the ConditionalEstimator class family. Two points remain I believe:

  • The RatioEstimator does its own shape checks etc, but are the shape conventions the same? Are all the shape checks passing for SNRE as well?
  • There should be a test case for 2D inputs with embedding nets.

Thanks a lot! 🙏

janfb avatar Jul 02 '24 12:07 janfb

I will no longer move RatioEstimator out of neural_nets because it looks like the conditional density estimators still live there... unless someone tells me that's wrong. It now has its own ratio_estimators.py file under neural_nets.

It looks like zukoflow expects batch_theta and batch_x to have shape (batch, *shape). I used that convention to determine x_shape and theta_shape as well, i.e. theta_shape = batch_theta.shape[1:] and x_shape = batch_x.shape[1:].
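Under the (batch, *shape) convention, the per-example shape is everything after the leading batch dimension. A small sketch of that inference follows; `infer_event_shapes` is a hypothetical helper, not a function from sbi, and NumPy arrays stand in for torch tensors.

```python
import numpy as np


def infer_event_shapes(batch_theta, batch_x):
    """Return (theta_shape, x_shape) by dropping the leading batch dimension."""
    if batch_theta.shape[0] != batch_x.shape[0]:
        raise ValueError("batch_theta and batch_x must share the batch size")
    return tuple(batch_theta.shape[1:]), tuple(batch_x.shape[1:])


batch_theta = np.zeros((8, 3))    # 8 parameter vectors of dimension 3
batch_x = np.zeros((8, 32, 32))   # 8 image-like observations
theta_shape, x_shape = infer_event_shapes(batch_theta, batch_x)
print(theta_shape, x_shape)
```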

I support sample_dim and batch_dim by forcing the inputs to the RatioEstimator to have the same prefix. No broadcasting is allowed. When the unnormalized_log_ratio is computed, the prefix shape gets flattened into a single effective batch, then unflattened back to the prefix shape.
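The flatten-then-unflatten handling described above can be sketched as follows. This is an illustration under assumptions, not the code in this PR: `unnormalized_log_ratio_sketch` and `toy_net` are hypothetical, and NumPy arrays stand in for torch tensors.

```python
import numpy as np


def unnormalized_log_ratio_sketch(theta, x, theta_shape, x_shape, net):
    """Evaluate net on inputs that share an arbitrary prefix shape."""
    theta_shape, x_shape = tuple(theta_shape), tuple(x_shape)
    # The prefix is whatever precedes the event shape, e.g. (sample, batch).
    theta_prefix = theta.shape[: theta.ndim - len(theta_shape)]
    x_prefix = x.shape[: x.ndim - len(x_shape)]

    if theta.shape[len(theta_prefix):] != theta_shape:
        raise ValueError("theta does not end in theta_shape")
    if x.shape[len(x_prefix):] != x_shape:
        raise ValueError("x does not end in x_shape")
    # No broadcasting: the prefixes must match exactly.
    if theta_prefix != x_prefix:
        raise ValueError(f"prefix mismatch: {theta_prefix} vs {x_prefix}")

    # Flatten the prefix into one effective batch dimension.
    flat_theta = theta.reshape(-1, *theta_shape)
    flat_x = x.reshape(-1, *x_shape)
    flat_out = net(flat_theta, flat_x)  # expected shape: (effective_batch,)
    # Restore the original prefix shape.
    return flat_out.reshape(theta_prefix)


def toy_net(theta, x):
    # Toy stand-in for a classifier: one scalar per example.
    return theta.sum(axis=1) + x.reshape(x.shape[0], -1).sum(axis=1)


theta = np.ones((5, 4, 3))     # prefix (5, 4), theta_shape (3,)
x = np.ones((5, 4, 2, 2))      # prefix (5, 4), x_shape (2, 2)
out = unnormalized_log_ratio_sketch(theta, x, (3,), (2, 2), toy_net)
print(out.shape)
```

Requiring identical prefixes keeps the flattening unambiguous; a caller who wants to pair one x with many theta has to repeat x explicitly before calling.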

bkmi avatar Jul 08 '24 11:07 bkmi

Assuming this passes CI, all tests related to ratios also pass on my machine:

pytest tests/inference_on_device_test.py tests/linearGaussian_snre_test.py tests/ratio_estimator_test.py 
................xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx................xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.x......................................................................................................... [ 59%]
.......................................................xx....................................................................................                                                               [100%]

bkmi avatar Jul 08 '24 13:07 bkmi

Right now test_lc2st_true_positiv_rate[LC2ST_NF] is failing, but that shouldn't affect merging this pull request @janfb

bkmi avatar Jul 08 '24 20:07 bkmi