API: simplify DensityEstimator base class
Simplify the DensityEstimator base class and make it an abstract class.
Remove unnecessary functions
- [x] add Estimator base class
- [x] doc shape conventions
- [ ] remove broadcasting stuff from log_prob and sample methods
- [ ] add VectorFieldEstimator
- [ ] add RatioEstimator
- [ ] sync with @michaeldeistler 's #1066
A step in the direction of #1046
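The description above asks for an abstract base class with unnecessary functions removed. Here is a rough, stdlib-only sketch of that idea, in which subclasses must implement only the `log_prob` and `sample` methods mentioned in the checklist. The signatures, the scalar shape convention, and the `UnitGaussian` toy subclass are assumptions for illustration, not sbi's API (sbi's estimators operate on torch tensors).

```python
import math
import random
from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple


class DensityEstimator(ABC):
    """Hypothetical minimal abstract base class for conditional density estimators.

    Only `log_prob` and `sample` are required; nothing else is imposed.
    Assumed shape convention (scalar inputs and conditions, for brevity):
      log_prob: input and condition are batches of equal length B -> B values
      sample:   num_samples draws per condition -> num_samples rows of length B
    """

    def __init__(self, condition_shape: Tuple[int, ...]):
        self._condition_shape = condition_shape

    @property
    def condition_shape(self) -> Tuple[int, ...]:
        return self._condition_shape

    @abstractmethod
    def log_prob(self, input: Sequence[float], condition: Sequence[float]) -> List[float]:
        """Return log p(input_i | condition_i) for each batch element i."""

    @abstractmethod
    def sample(self, num_samples: int, condition: Sequence[float]) -> List[List[float]]:
        """Return a (num_samples, B) nested list of draws, one column per condition."""


class UnitGaussian(DensityEstimator):
    """Hypothetical toy subclass: p(x | c) = Normal(x; mean=c, std=1), not trained."""

    def log_prob(self, input, condition):
        if len(input) != len(condition):
            raise ValueError("input and condition must have the same batch size")
        log_norm = -0.5 * math.log(2 * math.pi)
        return [log_norm - 0.5 * (x - c) ** 2 for x, c in zip(input, condition)]

    def sample(self, num_samples, condition):
        return [[random.gauss(c, 1.0) for c in condition] for _ in range(num_samples)]


est = UnitGaussian(condition_shape=(1,))
print(est.log_prob([0.0, 1.0], [0.0, 0.0]))  # approx. [-0.919, -1.419]
```

Marking the methods with `@abstractmethod` means `DensityEstimator((1,))`, or a subclass that forgets one of the methods, raises `TypeError` on instantiation, so the interface is enforced without extra runtime checks in the base class.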
Codecov Report
Attention: Patch coverage is 84.84848% with 5 lines in your changes missing coverage. Please review.
Project coverage is 76.86%. Comparing base (0b5f931) to head (d47f4a7).
:exclamation: Current head d47f4a7 differs from pull request most recent head c6ff223. Consider uploading reports for the commit c6ff223 to get more accurate results
@@ Coverage Diff @@
## main #1072 +/- ##
==========================================
- Coverage 85.13% 76.86% -8.27%
==========================================
Files 90 89 -1
Lines 6651 6558 -93
==========================================
- Hits 5662 5041 -621
- Misses 989 1517 +528
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 76.86% <84.84%> (-8.27%) | :arrow_down: |
Flags with carried forward coverage won't be shown.
| Files | Coverage Δ | |
|---|---|---|
| sbi/inference/snpe/snpe_a.py | 65.15% <100.00%> (-25.07%) | :arrow_down: |
| sbi/neural_nets/density_estimators/nflows_flow.py | 98.21% <100.00%> (+35.46%) | :arrow_up: |
| sbi/neural_nets/density_estimators/zuko_flow.py | 97.91% <100.00%> (+33.47%) | :arrow_up: |
| sbi/neural_nets/density_estimators/base.py | 65.00% <66.66%> (+9.44%) | :arrow_up: |
| sbi/utils/user_input_checks.py | 85.36% <62.50%> (+1.60%) | :arrow_up: |
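The Codecov numbers in the report are internally consistent, which a little arithmetic confirms. This sketch assumes Codecov's default of two-decimal precision with rounding down; the 28-of-33 split for the patch is inferred from "84.84848% with 5 lines missing" (28 / (28 + 5) = 84.84848...%) and is not stated in the report.

```python
def codecov_percent(hits: int, lines: int, precision: int = 2) -> float:
    """Coverage percentage, rounded down to `precision` decimals.

    Assumes Codecov's default rounding mode ("down"); integer floor division
    avoids floating-point surprises before the final scaling.
    """
    scale = 10 ** precision
    return (hits * 100 * scale // lines) / scale


reports = {"base 0b5f931": (5662, 6651), "head d47f4a7": (5041, 6558)}
for name, (hits, lines) in reports.items():
    # prints "base 0b5f931 85.13 misses: 989" and "head d47f4a7 76.86 misses: 1517"
    print(name, codecov_percent(hits, lines), "misses:", lines - hits)

# Patch coverage: 28 of 33 changed lines covered (5 missing).
print("patch", codecov_percent(28, 33, precision=5))  # prints "patch 84.84848"
```

The reported -8.27% is simply 76.86 - 85.13, and the hits/misses deltas (-621, +528) follow from the same totals.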
Finally it's green ^^
Great :). I just checked with @michaeldeistler, who is also working on the density estimator functionality in #1066, so we will delay merging this until #1066 is done.
#1138 conflicts with this PR. Since this PR modifies the base class, it would be a good idea to merge it quickly; otherwise we should drop it, as it will keep running into conflicts.
I am not sure how to resolve the conflict here:
- From our discussion during the sprint, I thought the goal was to avoid imposing constraints on the base class, so I removed the attributes that were not used in it (net/_condition_shape).
- #1138 adds an attribute that is not used in the class.
I think that if we keep the input/condition_shape attributes, it would make sense to add a mechanism that checks these shapes in all of the class's methods, something like:

    def _check_shape(self, x=None, theta=None):
        # Only validate the arguments that were actually passed.
        if x is not None:
            check_input_shape(x, self._input_shape)
        if theta is not None:
            check_condition_shape(theta, self.condition_shape)
Maybe this could also handle some of the reshaping from #1066?
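The `_check_shape` idea in this comment can be made concrete as a small, dependency-free sketch. `check_trailing_shape` and `ShapeCheckedEstimator` are hypothetical names, and the rule "trailing dims must match the stored shape, leading dims are batch dims" is an assumption, not sbi's existing behavior.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_trailing_shape(name: str, shape: Shape, expected: Shape) -> None:
    """Raise ValueError unless `shape` ends with `expected`.

    Any leading dimensions are treated as batch dimensions and are not checked.
    """
    shape, expected = tuple(shape), tuple(expected)
    n = len(expected)
    if len(shape) < n or shape[len(shape) - n:] != expected:
        raise ValueError(f"{name} has shape {shape}; expected trailing dims {expected}")


class ShapeCheckedEstimator:
    """Hypothetical estimator skeleton that stores event shapes and validates arguments."""

    def __init__(self, input_shape: Shape, condition_shape: Shape):
        self._input_shape = tuple(input_shape)
        self._condition_shape = tuple(condition_shape)

    def _check_shape(
        self,
        input_shape: Optional[Shape] = None,
        condition_shape: Optional[Shape] = None,
    ) -> None:
        # Only validate the arguments that were actually passed.
        if input_shape is not None:
            check_trailing_shape("input", input_shape, self._input_shape)
        if condition_shape is not None:
            check_trailing_shape("condition", condition_shape, self._condition_shape)
```

With `input_shape=(2,)` and `condition_shape=(3,)`, a call with shapes `(100, 2)` and `(100, 3)` passes, while `condition_shape=(100, 4)` raises `ValueError`. The sketch takes shapes rather than tensors to stay dependency-free; with torch one would pass `tensor.shape`, which is a `torch.Size` (a tuple subclass).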
Yes, this has a lot of conflicts with the shape-handling changes we made, and we want this PR to be merged ASAP.
@manuelgloeckler is planning to continue this PR. I think he will either merge this branch into a new feature branch off main, or start from scratch and cherry-pick the changes from this PR.
I see the point about the _check_shape function, but I think we have that already, at least for the condition shape:
https://github.com/sbi-dev/sbi/blob/005aeaca3c0ee246a39ef78a1947bc1786a85f71/sbi/neural_nets/density_estimators/base.py#L132-L157
And we also have other functions that are used by the SBI methods to check the correctness of the shape, e.g., https://github.com/sbi-dev/sbi/blob/005aeaca3c0ee246a39ef78a1947bc1786a85f71/sbi/neural_nets/density_estimators/shape_handling.py#L53
I think this one can be closed in favor of #1151. Or not, @manuelgloeckler @tomMoral?
Yes, closing this in favor of #1151.