Integrate new density estimator interface into SBI
Description
Following #952, we need to integrate the new Density Estimator class into SBI.
Checklist
Following steps:
- [ ] 1) Change `posterior_nn` helper function to return the new Density Estimator class.
- [ ] 2) Change `likelihood_nn` helper function to return the new Density Estimator class.
- [ ] 3) Change `classifier_nn` helper function to return (new Ratio Estimator class?).
- [ ] 4) Integrate new interface into SNPE base classes.
- [ ] 5) Integrate new interface into SNLE base classes.
- [ ] 6) Integrate new interface into SNRE base classes.
- [ ] 7) Try out a new density estimator, e.g., from Zuko.
These steps can be worked on largely in parallel.
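As a rough illustration of steps 1) and 2), here is a helper that returns a builder function which infers dimensions from a batch of parameters and data and returns an estimator object rather than a bare network. Everything here is hypothetical: the `DensityEstimator` base class, its `log_prob(inputs, condition)` / `sample(sample_shape, condition)` signatures, the `DiagonalGaussianEstimator` example and the simplified `posterior_nn` are assumptions for illustration, not the interface from #952 or the actual `sbi` helper.

```python
import torch
from torch import Tensor, nn
from torch.distributions import Distribution, Independent, Normal


class DensityEstimator(nn.Module):
    """Hypothetical base class for a conditional density q(inputs | condition)."""

    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        raise NotImplementedError

    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        raise NotImplementedError


class DiagonalGaussianEstimator(DensityEstimator):
    """Hypothetical example: diagonal Gaussian whose mean and log-std are
    predicted from the condition by a small MLP."""

    def __init__(self, input_dim: int, condition_dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(condition_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * input_dim),
        )

    def _distribution(self, condition: Tensor) -> Distribution:
        mean, log_std = self.net(condition).chunk(2, dim=-1)
        # Independent(..., 1) treats the last dimension as the event dimension,
        # so log_prob sums over the parameter dimensions.
        return Independent(Normal(mean, log_std.exp()), 1)

    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        return self._distribution(condition).log_prob(inputs)

    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        return self._distribution(condition).sample(sample_shape)


def posterior_nn(hidden_features: int = 32):
    """Hypothetical, simplified helper: returns a builder that infers the
    dimensions of theta and x from a batch and returns a DensityEstimator."""

    def build_fn(batch_theta: Tensor, batch_x: Tensor) -> DensityEstimator:
        return DiagonalGaussianEstimator(
            input_dim=batch_theta.shape[1],
            condition_dim=batch_x.shape[1],
            hidden=hidden_features,
        )

    return build_fn


# Usage: build an estimator from a batch of simulated (theta, x) pairs.
theta, x = torch.randn(100, 2), torch.randn(100, 3)
estimator = posterior_nn()(theta, x)
print(estimator.log_prob(theta, x).shape)  # torch.Size([100])
print(estimator.sample((5,), x[:1]).shape)  # torch.Size([5, 1, 2])
```

In this sketch, returning a builder rather than a ready-made network lets the caller defer construction until it has seen a batch of training data, so the helper itself does not need to know the dimensions of θ or x.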
@gmoss13 I will start with 1) posterior_nn and 4) integration into SNPE base classes.
Thanks for setting up the issue! I will begin with 2) likelihood_nn and 5) integration into SNLE base classes.
Actually, after looking into it a bit, I think DensityEstimator needs to expose a few methods from torch.nn.Module. For instance, any training method would need:
- parameters: To get all the parameters to optimize.
- to: To move tensors to device.
- train: To switch to train mode.
- eval: To switch to eval mode.
- zero_grad: (optional), but currently used in SNPE base training.
The easiest solution would be to make "DensityEstimator" (or any general "Estimator") a subclass of nn.Module.
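A minimal sketch of the nn.Module subclassing idea, assuming a hypothetical `DensityEstimator` base class with a `log_prob` method and a negative-log-likelihood `loss`. Because the base class subclasses `nn.Module`, the training loop can call `parameters()`, `to()`, `train()`, `eval()` and `zero_grad()` on the estimator without any extra wrapper code. The `LinearGaussianEstimator` toy model and the `train_estimator` loop are illustrative assumptions, not `sbi`'s SNPE training code.

```python
import torch
from torch import Tensor, nn
from torch.distributions import Normal


class DensityEstimator(nn.Module):
    """Hypothetical base class. Subclassing nn.Module provides parameters(),
    to(), train(), eval() and zero_grad() without any extra code."""

    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        raise NotImplementedError

    def loss(self, inputs: Tensor, condition: Tensor) -> Tensor:
        # Negative log-likelihood per sample, the usual maximum-likelihood loss.
        return -self.log_prob(inputs, condition)


class LinearGaussianEstimator(DensityEstimator):
    """Hypothetical toy estimator: N(inputs; W @ condition + b, diag(sigma^2))."""

    def __init__(self, input_dim: int, condition_dim: int):
        super().__init__()
        self.linear = nn.Linear(condition_dim, input_dim)
        self.log_std = nn.Parameter(torch.zeros(input_dim))

    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        dist = Normal(self.linear(condition), self.log_std.exp())
        return dist.log_prob(inputs).sum(dim=-1)


def train_estimator(
    estimator: DensityEstimator,
    theta: Tensor,
    x: Tensor,
    epochs: int = 200,
    lr: float = 1e-2,
    device: str = "cpu",
) -> DensityEstimator:
    estimator.to(device)  # nn.Module.to: moves all parameters and buffers
    theta, x = theta.to(device), x.to(device)
    optimizer = torch.optim.Adam(estimator.parameters(), lr=lr)  # nn.Module.parameters
    estimator.train()  # nn.Module.train: switch to train mode
    for _ in range(epochs):
        estimator.zero_grad()  # nn.Module.zero_grad: clear gradients of all parameters
        loss = estimator.loss(theta, x).mean()
        loss.backward()
        optimizer.step()
    estimator.eval()  # nn.Module.eval: switch to eval mode
    return estimator
```

A plain `nn.Module` subclass also stays compatible with standard PyTorch tooling such as `state_dict()` / `load_state_dict()` for checkpointing.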
I like the idea with nn.Module 👍 this makes it easier to integrate things with PPL frameworks as well.
- [ ] 3) Change `classifier_nn` helper function to return (new Ratio Estimator class?).
Thinking about this some more, I'm not sure how necessary this is. Currently, the only ratio estimators I am aware of are classifiers, for which we just use nn.Modules directly. I'm not sure what a RatioEstimator wrapper would add.
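To illustrate why a wrapper may not add much: in the standard binary-classification formulation of neural ratio estimation, the classifier is just an `nn.Module` that scores (θ, x) pairs. It is trained to tell jointly drawn pairs from pairs whose θ has been shuffled across the batch, and with balanced classes its logit approximates log p(θ, x) / (p(θ) p(x)). This is a hypothetical minimal version (the `Classifier`, `ratio_loss` and `train_classifier` names are assumptions), not the `sbi` SNRE implementation.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Classifier(nn.Module):
    """Plain nn.Module mapping (theta, x) to a single logit.

    After training, the logit approximates log p(theta, x) / (p(theta) p(x)).
    """

    def __init__(self, theta_dim: int, x_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(theta_dim + x_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        return self.net(torch.cat([theta, x], dim=-1)).squeeze(-1)


def ratio_loss(classifier: nn.Module, theta: Tensor, x: Tensor) -> Tensor:
    # Label 1: (theta, x) drawn jointly. Label 0: theta shuffled across the
    # batch, which approximates independent draws from p(theta) p(x).
    theta_marginal = theta[torch.randperm(theta.shape[0])]
    logits_joint = classifier(theta, x)
    logits_marginal = classifier(theta_marginal, x)
    loss_joint = F.binary_cross_entropy_with_logits(
        logits_joint, torch.ones_like(logits_joint)
    )
    loss_marginal = F.binary_cross_entropy_with_logits(
        logits_marginal, torch.zeros_like(logits_marginal)
    )
    return loss_joint + loss_marginal


def train_classifier(
    theta: Tensor, x: Tensor, steps: int = 300, lr: float = 1e-2
) -> Classifier:
    classifier = Classifier(theta.shape[1], x.shape[1])
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = ratio_loss(classifier, theta, x)
        loss.backward()
        optimizer.step()
    return classifier.eval()  # nn.Module.eval returns the module itself
```

Since `Classifier` is already an `nn.Module`, it exposes `parameters()`, `to()`, `train()` and `eval()` directly; a `RatioEstimator` wrapper would mainly add a name around it.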
- [ ] 6) Integrate new interface into SNRE base classes.
This would only be necessary if we do 3.
Closing this in favour of #992, as everything else has already been addressed.