API simplify DensityEstimator base class

Open tomMoral opened this issue 1 year ago • 5 comments

Simplify the DensityEstimator base class and make it an abstract class. Remove unnecessary functions.

  • [x] add Estimator base class
  • [x] doc shape conventions
  • [ ] remove broadcasting stuff from log_prob and sample methods
  • [ ] add VectorFieldEstimator
  • [ ] add RatioEstimator
  • [ ] sync with @michaeldeistler's #1066

A step in the direction of #1046
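
To illustrate what "make it an abstract class" could look like, here is a minimal sketch using Python's `abc` module. It is not the actual sbi class: the method signatures, the list-based inputs, and the toy `UnitGaussian` subclass are assumptions made only for illustration.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import List


class DensityEstimator(ABC):
    """Hypothetical abstract interface; NOT the actual sbi base class."""

    @abstractmethod
    def log_prob(self, inputs: List[float], condition: float) -> List[float]:
        """Return one log-density per input, given the condition."""

    @abstractmethod
    def sample(self, num_samples: int, condition: float) -> List[float]:
        """Draw `num_samples` values, given the condition."""


class UnitGaussian(DensityEstimator):
    """Toy subclass (assumption): unit-variance Gaussian centred on the condition."""

    def log_prob(self, inputs, condition):
        # Log-density of N(condition, 1) evaluated at each input.
        log_norm = -0.5 * math.log(2.0 * math.pi)
        return [log_norm - 0.5 * (x - condition) ** 2 for x in inputs]

    def sample(self, num_samples, condition):
        return [random.gauss(condition, 1.0) for _ in range(num_samples)]
```

With this layout, `DensityEstimator()` raises `TypeError` because its abstract methods are not implemented, while `UnitGaussian()` can be instantiated and used. The base class itself carries no state or helper logic.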

tomMoral avatar Mar 21 '24 08:03 tomMoral

Codecov Report

Attention: Patch coverage is 84.84848%, with 5 lines in your changes missing coverage. Please review.

Project coverage is 76.86%. Comparing base (0b5f931) to head (d47f4a7).

:exclamation: Current head d47f4a7 differs from pull request most recent head c6ff223. Consider uploading reports for the commit c6ff223 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1072      +/-   ##
==========================================
- Coverage   85.13%   76.86%   -8.27%     
==========================================
  Files          90       89       -1     
  Lines        6651     6558      -93     
==========================================
- Hits         5662     5041     -621     
- Misses        989     1517     +528     
Flag Coverage Δ
unittests 76.86% <84.84%> (-8.27%) :arrow_down:

Flags with carried forward coverage won't be shown.

Files                                               Coverage Δ
sbi/inference/snpe/snpe_a.py                        65.15% <100.00%> (-25.07%) :arrow_down:
sbi/neural_nets/density_estimators/nflows_flow.py   98.21% <100.00%> (+35.46%) :arrow_up:
sbi/neural_nets/density_estimators/zuko_flow.py     97.91% <100.00%> (+33.47%) :arrow_up:
sbi/neural_nets/density_estimators/base.py          65.00% <66.66%> (+9.44%) :arrow_up:
sbi/utils/user_input_checks.py                      85.36% <62.50%> (+1.60%) :arrow_up:

... and 51 files with indirect coverage changes

codecov[bot] avatar Mar 21 '24 10:03 codecov[bot]

Finally it's green ^^

tomMoral avatar Mar 21 '24 12:03 tomMoral

Great :). I just checked with @michaeldeistler, who also works on the Density estimator functionality in #1066; so, we will delay merging this until #1066 is done.

manuelgloeckler avatar Mar 22 '24 12:03 manuelgloeckler

#1138 conflicts with this PR. As this PR modifies the base class, it would be a good idea to merge it quickly; otherwise we should drop it, as it will keep running into conflicts that need resolving.

I am not sure how to solve the conflict here:

  • From our discussion during the sprint, I thought the goal was to avoid having constraints on the base class, so I went on to remove the attributes that were not used in it (net/_condition_shape).
  • #1138 adds an attribute which is not used in the class.

I think if we have input/condition_shape attributes, it would make sense to add a mechanism to check these shapes in all the class functions. Something like:

def _check_shape(self, x=None, theta=None):
    if x is not None:
        check_input_shape(x, self._input_shape)
    if theta is not None:
        check_condition_shape(theta, self.condition_shape)

Maybe this could also handle some of the reshaping from #1066?
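
A self-contained sketch of the kind of shape check proposed here, assuming that only the trailing (event) dimensions must match the stored shapes and that any leading batch dimensions are free. The helper `check_trailing_shape` and the class `ShapeCheckedEstimator` are hypothetical names, not existing sbi functions. The check only reads a `.shape` attribute, so it would accept any tensor-like object.

```python
from types import SimpleNamespace
from typing import Sequence, Tuple


def check_trailing_shape(name: str, actual: Sequence[int], expected: Tuple[int, ...]) -> None:
    """Raise ValueError unless the trailing dims of `actual` equal `expected`."""
    actual = tuple(actual)
    expected = tuple(expected)
    n = len(expected)
    # Slice relative to the length so that an empty `expected` yields ().
    trailing = actual[len(actual) - n:]
    if len(actual) < n or trailing != expected:
        raise ValueError(
            f"{name} has shape {actual}, expected trailing dimensions {expected}."
        )


class ShapeCheckedEstimator:
    """Hypothetical estimator that stores the shapes it was built for."""

    def __init__(self, input_shape: Tuple[int, ...], condition_shape: Tuple[int, ...]):
        self._input_shape = tuple(input_shape)
        self._condition_shape = tuple(condition_shape)

    def _check_shape(self, x=None, theta=None) -> None:
        # Each argument is optional, so every method can check only what it receives.
        if x is not None:
            check_trailing_shape("x", x.shape, self._input_shape)
        if theta is not None:
            check_trailing_shape("theta", theta.shape, self._condition_shape)


# Stand-ins for tensors: only the `.shape` attribute is used by the check.
est = ShapeCheckedEstimator(input_shape=(3,), condition_shape=(2,))
est._check_shape(x=SimpleNamespace(shape=(10, 3)), theta=SimpleNamespace(shape=(10, 2)))
```

Calling `est._check_shape(x=SimpleNamespace(shape=(10, 4)))` raises a `ValueError`, because the trailing dimension `4` does not match the stored input shape `(3,)`.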

tomMoral avatar Apr 24 '24 20:04 tomMoral

Yes, this has a lot of conflicts with the changes we made to the shaping, and we want this PR to be merged ASAP.

@manuelgloeckler planned to continue this PR. I think he will either merge this branch into a new feature branch from main, or start from scratch and cherry-pick the changes from this PR.

I see the point about the _check_shape function, but I think we have that already, at least for the condition shape: https://github.com/sbi-dev/sbi/blob/005aeaca3c0ee246a39ef78a1947bc1786a85f71/sbi/neural_nets/density_estimators/base.py#L132-L157

And we also have other functions that are used by the SBI methods to check the correctness of the shape, e.g., https://github.com/sbi-dev/sbi/blob/005aeaca3c0ee246a39ef78a1947bc1786a85f71/sbi/neural_nets/density_estimators/shape_handling.py#L53

janfb avatar Apr 25 '24 09:04 janfb

I think this one can be closed in favor of #1151, or not, @manuelgloeckler @tomMoral?

janfb avatar Jun 03 '24 14:06 janfb

Yes, closing this in favor of #1151.

tomMoral avatar Jun 06 '24 10:06 tomMoral