sbi
Refactor abstract classes for custom density estimators
Initiated by @tomMoral's input in #1019, we are planning to give users more flexibility in defining their own custom density estimators by adding another layer of abstraction: an `Estimator` base class.
Here is a draft resulting from a discussion with @michaeldeistler and @manuelgloeckler (indentation level shows inheritance; class methods are in parentheses):
See also #1041.
This will be relevant for #963 and the `RatioEstimator` as well, @bkmi.
@michaeldeistler @janfb @manuelgloeckler @jnsbck
I propose we remove the `loss` method from this abstraction. It should be possible to define an estimator without assuming how it will be trained.
This is an issue for the MDN, for all ratio estimators, and, I suspect, for the vector-based estimators as well.
The MDN can be trained in multiple ways, depending on whether it is used within SNPE_A or SNPE_C. The ratio estimators differ only in how they are trained, NOT in the features of the estimator itself. Similarly, flow matching and score matching will require very similar estimators but use different losses.
I argue that an `Estimator` should be agnostic to the training algorithm at the time you instantiate it; otherwise, why not include this abstraction in the training algorithm itself?
The loss should live at the "inference" level (i.e., in classes such as SNRE_A, SNPE_B, etc.), rather than at the estimator level.
What do you all think?
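To make the argument concrete, here is a minimal, framework-free sketch in which one ratio estimator object is scored by two different losses owned by inference-level classes. All names (`RatioClassifier`, `ToyNRE_A`, `ToyNRE_B`, `log_sigmoid`) are hypothetical and not sbi's API; the losses are simplified stand-ins loosely modelled on the binary (NRE-A-like) and multiclass (NRE-B-like) objectives, and the "network" is a toy bilinear score.

```python
import math
from typing import Sequence


class RatioClassifier:
    """Hypothetical estimator: maps a (theta, x) pair to a logit.

    It deliberately has no `loss` method: it says nothing about how it is trained.
    """

    def __init__(self, weight: float = 1.0, bias: float = 0.0) -> None:
        self.weight = weight
        self.bias = bias

    def logit(self, theta: float, x: float) -> float:
        # Toy bilinear score; a real estimator would be a neural network.
        return self.weight * theta * x + self.bias


def log_sigmoid(z: float) -> float:
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


class ToyNRE_A:
    """Hypothetical inference-level class owning a binary (NRE-A-like) loss."""

    @staticmethod
    def loss(est: RatioClassifier, theta: Sequence[float], x: Sequence[float]) -> float:
        n = len(theta)
        total = 0.0
        for i in range(n):
            # Class 1: matched pair; class 0: theta shifted by one index
            # (a deterministic stand-in for shuffling).
            joint = est.logit(theta[i], x[i])
            marginal = est.logit(theta[(i + 1) % n], x[i])
            total -= log_sigmoid(joint) + log_sigmoid(-marginal)
        return total / n


class ToyNRE_B:
    """Hypothetical inference-level class owning a multiclass (NRE-B-like) loss."""

    @staticmethod
    def loss(est: RatioClassifier, theta: Sequence[float], x: Sequence[float]) -> float:
        n = len(theta)
        total = 0.0
        for i in range(n):
            # Classify which theta in the batch generated x[i] (log-softmax).
            logits = [est.logit(theta[j], x[i]) for j in range(n)]
            m = max(logits)
            log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
            total -= logits[i] - log_norm
        return total / n


# Same estimator object, two different training objectives.
theta = [0.5, -1.0, 2.0]
x = [0.4, -1.2, 1.8]
est = RatioClassifier(weight=1.0)
loss_a = ToyNRE_A.loss(est, theta, x)
loss_b = ToyNRE_B.loss(est, theta, x)
```

The estimator class never changes between the two objectives; only the inference-level class decides how it is trained, which is the separation proposed above.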
I think abstracting the model, the loss, and the optimization/training separately makes a lot of sense, though I think it would require a ton of changes to the code in `inference`.
FYI, I think the answer to this was to let `loss` exist in `DensityEstimator`, but not in the other estimator classes.
This is solved now: the abstract base class `ConditionalEstimator` in `neural_nets/density_estimators` takes care of shapes and requires its children to implement `loss`, `log_prob`, and `sample`. `ConditionalDensityEstimator` is the class for most flows (`nflows` and `zuko`), and `ConditionalVectorFieldEstimator` will be the class for score-matching and (maybe) flow-matching methods.
The `RatioEstimator` is kept separate from this hierarchy and lives in its own `ratio_estimators.py` module.
I think the `neural_nets` module is still somewhat messy and should be refactored in the future; see #1190.
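The class layout described in this thread can be sketched roughly as follows. This is a framework-free toy, not sbi's actual code: the real classes operate on torch tensors, and the method signatures, the `_check_shapes`/`_check_condition` helpers, `unnormalized_log_ratio`, and the `ToyGaussianEstimator` example are hypothetical. It also assumes, as the text implies, that `ConditionalVectorFieldEstimator` derives from `ConditionalEstimator`.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class ConditionalEstimator(ABC):
    """Base class: validates shapes; children implement loss, log_prob, sample."""

    def __init__(self, input_dim: int, condition_dim: int) -> None:
        self.input_dim = input_dim
        self.condition_dim = condition_dim

    def _check_condition(self, condition: Sequence[float]) -> None:
        # Hypothetical shape check on the conditioning variable.
        if len(condition) != self.condition_dim:
            raise ValueError(f"expected condition of length {self.condition_dim}, got {len(condition)}")

    def _check_shapes(self, inputs: Sequence[float], condition: Sequence[float]) -> None:
        # Hypothetical shape check on inputs and condition.
        if len(inputs) != self.input_dim:
            raise ValueError(f"expected input of length {self.input_dim}, got {len(inputs)}")
        self._check_condition(condition)

    @abstractmethod
    def loss(self, inputs: Sequence[float], condition: Sequence[float]) -> float: ...

    @abstractmethod
    def log_prob(self, inputs: Sequence[float], condition: Sequence[float]) -> float: ...

    @abstractmethod
    def sample(self, num_samples: int, condition: Sequence[float],
               seed: Optional[int] = None) -> List[List[float]]: ...


class ConditionalDensityEstimator(ConditionalEstimator, ABC):
    """Home of most flows (nflows, zuko) in the design described above."""


class ConditionalVectorFieldEstimator(ConditionalEstimator, ABC):
    """Placeholder for score-matching and (maybe) flow-matching estimators."""


class RatioEstimator(ABC):
    """Separate hierarchy: scores (theta, x) pairs; no loss, no sampling."""

    @abstractmethod
    def unnormalized_log_ratio(self, theta: Sequence[float], x: Sequence[float]) -> float: ...


class ToyGaussianEstimator(ConditionalDensityEstimator):
    """Hypothetical example: N(inputs; mean=condition, cov=I), so input_dim == condition_dim."""

    def __init__(self, dim: int) -> None:
        super().__init__(input_dim=dim, condition_dim=dim)

    def log_prob(self, inputs: Sequence[float], condition: Sequence[float]) -> float:
        self._check_shapes(inputs, condition)
        return sum(-0.5 * (v - mu) ** 2 - 0.5 * math.log(2 * math.pi)
                   for v, mu in zip(inputs, condition))

    def loss(self, inputs: Sequence[float], condition: Sequence[float]) -> float:
        # Negative log-likelihood, a maximum-likelihood objective for this toy.
        return -self.log_prob(inputs, condition)

    def sample(self, num_samples: int, condition: Sequence[float],
               seed: Optional[int] = None) -> List[List[float]]:
        self._check_condition(condition)
        rng = random.Random(seed)
        return [[rng.gauss(mu, 1.0) for mu in condition] for _ in range(num_samples)]
```

Because the abstract methods are declared with `abc.abstractmethod`, a subclass that forgets to implement `loss`, `log_prob`, or `sample` cannot be instantiated, while `RatioEstimator` stays outside that contract entirely.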