
Refactor abstract classes for custom density estimators

Open janfb opened this issue 11 months ago • 4 comments

Prompted by @tomMoral's input in #1019, we are planning to give users more flexibility in defining custom density estimators by adding another layer of abstraction: an `Estimator` base class.

Here is a draft resulting from a discussion with @michaeldeistler and @manuelgloeckler

Indent levels show inheritance; class methods are in parentheses: [image attachment]

see also #1041
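One hypothetical way such a hierarchy could look in code. Since the image is not reproduced here, the class names, method names, and signatures below are assumptions based on this thread, not the actual sbi API:

```python
from abc import ABC, abstractmethod


class Estimator(ABC):
    """Top of the hierarchy: owns the network and the condition shape."""

    def __init__(self, net, condition_shape):
        self.net = net
        self.condition_shape = condition_shape


class DensityEstimator(Estimator):
    """Adds the density-estimation interface on top of Estimator."""

    @abstractmethod
    def log_prob(self, inputs, condition): ...

    @abstractmethod
    def sample(self, sample_shape, condition): ...

    @abstractmethod
    def loss(self, inputs, condition): ...


class NFlowsFlow(DensityEstimator):
    """Example leaf class wrapping an nflows-style flow object."""

    def log_prob(self, inputs, condition):
        return self.net.log_prob(inputs, context=condition)

    def sample(self, sample_shape, condition):
        return self.net.sample(sample_shape, context=condition)

    def loss(self, inputs, condition):
        # Default training objective: negative log-likelihood.
        return -self.log_prob(inputs, condition)
```

The indent levels in the draft presumably map onto an inheritance chain of this kind, with each level adding methods to the one above it.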

janfb · Mar 19 '24 20:03

This will also be relevant for #963 and the `RatioEstimator`, cc @bkmi.

janfb · Mar 19 '24 20:03

@michaeldeistler @janfb @manuelgloeckler @jnsbck

I propose we remove the `loss` method from this abstraction. The estimator should be definable without assuming a particular way to train it.

This is an issue for the MDN and for all ratio estimators, and I suspect it will be an issue for the vector-based estimators as well.

The MDN is trained in different ways depending on whether it is used in SNPE_A or SNPE_C. The ratio estimators differ only in how they are trained, not in the features of the estimator itself. Similarly, flow matching and score matching will require extremely similar estimators but have different losses.

I argue that an `Estimator` should be agnostic to the training algorithm at the time it is instantiated; otherwise, why not fold this abstraction into the training algorithm itself?

The loss should live at the "inference" level (i.e. in classes like `SNRE_A`, `SNPE_B`, etc.) rather than at the estimator level.
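A minimal sketch of what this separation could look like, using an SNRE_A-style binary-classification loss as the example. The class names mirror this thread, but the signatures and the `log_sigmoid` helper are hypothetical, not the sbi API:

```python
import math


def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)).
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


class RatioEstimator:
    """Training-agnostic: it only knows how to produce logits."""

    def __init__(self, net):
        self.net = net

    def logits(self, theta, x):
        return self.net(theta, x)


class SNRE_A:
    """The loss lives at the inference level, not on the estimator."""

    def __init__(self, estimator):
        self.estimator = estimator

    def loss(self, theta, x, theta_shuffled):
        # Joint pairs (theta, x) should score high; marginal pairs
        # (theta_shuffled, x) should score low.
        logit_joint = self.estimator.logits(theta, x)
        logit_marginal = self.estimator.logits(theta_shuffled, x)
        return -(log_sigmoid(logit_joint) + log_sigmoid(-logit_marginal))
```

Under this design, an SNRE_B-style class would wrap the exact same `RatioEstimator` and differ only in its `loss` method, which is the point of the proposal.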

What do you all think?

bkmi · Mar 20 '24 17:03

I think abstracting the model, loss, and optimization/training separately makes a lot of sense. It would require a ton of changes to the code in `inference`, though.

jnsbck · Mar 21 '24 09:03

FYI, I think the resolution here was to let `loss` exist in `DensityEstimator`, but not in the other classes.

bkmi · Mar 21 '24 15:03

This is solved now: we have a `ConditionalEstimator` abstract base class in `neural_nets/density_estimators` that takes care of shapes and requires children to implement `loss`, `log_prob`, and `sample`. `ConditionalDensityEstimator` is the class for most flows (nflows and zuko), and `ConditionalVectorFieldEstimator` will be the class for score matching and (maybe) flow matching methods.

The `RatioEstimator` is separate from that and lives in its own `ratio_estimators.py` bubble.
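A rough sketch of that base class's contract, as described above. This is a simplified stand-in (plain Python, no torch), not the actual sbi implementation; the shape-checking helper name is an assumption:

```python
from abc import ABC, abstractmethod


class ConditionalEstimator(ABC):
    """Base class: fixes input/condition shapes, leaves the rest abstract."""

    def __init__(self, input_shape, condition_shape):
        self.input_shape = tuple(input_shape)
        self.condition_shape = tuple(condition_shape)

    def _check_condition_shape(self, condition):
        # The trailing dims of `condition` must match condition_shape;
        # any leading dims are treated as batch dims.
        ndim = len(self.condition_shape)
        if tuple(condition.shape[-ndim:]) != self.condition_shape:
            raise ValueError(
                f"Expected trailing condition dims {self.condition_shape}, "
                f"got {tuple(condition.shape)}."
            )

    @abstractmethod
    def loss(self, inputs, condition): ...

    @abstractmethod
    def log_prob(self, inputs, condition): ...

    @abstractmethod
    def sample(self, sample_shape, condition): ...
```

Children such as the flow wrappers then only implement the three abstract methods, while shape validation stays in one place.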

The `neural_nets` module is still kind of a mess, I think, and should be refactored in the future; see #1190.

janfb · Jul 09 '24 16:07