Integrate new density estimator interface into SBI

manuelgloeckler opened this issue 1 year ago • 5 comments

Description

Following #952, we need to integrate the new DensityEstimator class into sbi.

Checklist

The following steps are needed:

  • [ ] 1) Change the posterior_nn helper function to return the new DensityEstimator class.
  • [ ] 2) Change the likelihood_nn helper function to return the new DensityEstimator class.
  • [ ] 3) Change the classifier_nn helper function to return (a new ratio estimator class?).
  • [ ] 4) Integrate the new interface into the SNPE base classes.
  • [ ] 5) Integrate the new interface into the SNLE base classes.
  • [ ] 6) Integrate the new interface into the SNRE base classes.
  • [ ] 7) Try out a new density estimator, e.g. one from Zuko.

These steps can largely be worked on in parallel.

manuelgloeckler, Feb 27 '24 09:02

@gmoss13 I will start with 1) posterior_nn and 4) integration into SNPE base classes.

manuelgloeckler, Feb 28 '24 08:02

Thanks for setting up the issue! I will begin with 2) likelihood_nn and 5) integration into SNLE base classes.

gmoss13, Feb 28 '24 09:02

Actually, looking into it a bit, I think DensityEstimator needs to expose a few methods from torch.nn.Module. For instance, any training method would need:

  • parameters: to get all the parameters to optimize.
  • to: to move the module's parameters and buffers to a device.
  • train: to switch to training mode.
  • eval: to switch to evaluation mode.
  • zero_grad: optional, but currently used in the SNPE base training loop.
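The methods listed above are exactly the ones a typical PyTorch training loop calls on the model. The following is a minimal sketch, not sbi code: the ToyGaussian estimator and its loss(theta, x) method are hypothetical, chosen only to show that a loop touching nothing but parameters, to, train, zero_grad and eval (plus a loss) works with any nn.Module.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyGaussian(nn.Module):
    """Hypothetical conditional Gaussian q(theta | x) = N(W x + b, sigma^2)."""

    def __init__(self):
        super().__init__()
        self.mean = nn.Linear(1, 1)
        self.log_std = nn.Parameter(torch.zeros(1))

    def loss(self, theta, x):
        # Assumed interface: per-sample negative log-likelihood of theta given x.
        dist = Normal(self.mean(x), self.log_std.exp())
        return -dist.log_prob(theta).sum(-1)


def train_estimator(estimator, theta, x, epochs=200, lr=1e-2, device="cpu"):
    # Uses only nn.Module methods (to, parameters, train, zero_grad, eval) plus loss().
    estimator.to(device)
    theta, x = theta.to(device), x.to(device)
    optimizer = torch.optim.Adam(estimator.parameters(), lr=lr)
    estimator.train()
    for _ in range(epochs):
        estimator.zero_grad()
        loss = estimator.loss(theta, x).mean()
        loss.backward()
        optimizer.step()
    estimator.eval()
    return estimator


torch.manual_seed(0)
x = torch.randn(500, 1)
theta = 2 * x + 0.1 * torch.randn(500, 1)  # toy data, for illustration only

estimator = ToyGaussian()
with torch.no_grad():
    initial_loss = estimator.loss(theta, x).mean().item()
train_estimator(estimator, theta, x)
with torch.no_grad():
    final_loss = estimator.loss(theta, x).mean().item()
```

Calling estimator.zero_grad() here is equivalent to optimizer.zero_grad(), since the optimizer holds all of the module's parameters.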

The easiest solution would be to make DensityEstimator (or any general Estimator base class) a subclass of nn.Module.
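As a rough illustration of this proposal, here is a sketch of an abstract base class that subclasses nn.Module, so that parameters(), to(), train(), eval() and zero_grad() are inherited rather than re-exposed by hand. The method names log_prob, sample and loss, their signatures, and the ConditionalGaussian subclass are assumptions made for this example and are not the actual sbi API.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn
from torch.distributions import Normal


class DensityEstimator(nn.Module, ABC):
    """Hypothetical base class: subclassing nn.Module provides parameters(), to(),
    train(), eval() and zero_grad(). Method names here are assumptions."""

    @abstractmethod
    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        """Log-density of inputs given condition."""

    @abstractmethod
    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        """Draw samples given condition."""

    def loss(self, inputs: Tensor, condition: Tensor) -> Tensor:
        # Default training objective: negative log-likelihood.
        return -self.log_prob(inputs, condition)


class ConditionalGaussian(DensityEstimator):
    """Toy concrete estimator: diagonal Gaussian with a condition-dependent mean."""

    def __init__(self, input_dim: int, condition_dim: int):
        super().__init__()
        self.mean = nn.Linear(condition_dim, input_dim)
        self.log_std = nn.Parameter(torch.zeros(input_dim))

    def _dist(self, condition: Tensor) -> Normal:
        return Normal(self.mean(condition), self.log_std.exp())

    def log_prob(self, inputs: Tensor, condition: Tensor) -> Tensor:
        return self._dist(condition).log_prob(inputs).sum(-1)

    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        return self._dist(condition).sample(sample_shape)


estimator = ConditionalGaussian(input_dim=2, condition_dim=3)
n_params = sum(p.numel() for p in estimator.parameters())  # 3*2 weights + 2 biases + 2 log-stds
estimator.eval()
condition = torch.randn(4, 3)
samples = estimator.sample(torch.Size([10]), condition)  # shape (10, 4, 2)
log_probs = estimator.log_prob(samples, condition)  # shape (10, 4)
```

Combining nn.Module with ABC keeps the base class non-instantiable while concrete estimators (wrappers around nflows or Zuko flows, say) only need to implement log_prob and sample.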

manuelgloeckler, Feb 28 '24 10:02

I like the idea with nn.Module 👍 this makes it easier to integrate things with PPL frameworks as well.

janfb, Feb 28 '24 12:02

  • [ ] 3) Change classifier_nn helper function to return (new Ratio Estimator Class?)

Thinking about this some more, I'm not sure this is necessary. The only ratio estimators I am aware of are classifiers, for which we already use nn.Modules directly, so I'm not sure what a RatioEstimator wrapper would add.
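To make the point concrete: in classifier-based ratio estimation, a plain nn.Module is trained to tell jointly drawn (theta, x) pairs apart from pairs with shuffled theta, and with balanced classes its logit approximates log p(theta, x) / (p(theta) p(x)). The sketch below shows this standard contrastive objective; the toy simulator x = theta + noise, the network size and the helper name log_ratio are assumptions for illustration, and this is not sbi's SNRE implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy simulator (assumption): x = theta + small Gaussian noise.
theta = torch.randn(1000, 1)
x = theta + 0.1 * torch.randn(1000, 1)

# The ratio estimator is just an nn.Module classifier on concatenated (theta, x).
classifier = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-2)

for _ in range(300):
    optimizer.zero_grad()
    # Shuffling theta breaks the dependence on x, giving samples from p(theta) p(x).
    theta_marginal = theta[torch.randperm(theta.shape[0])]
    logits_joint = classifier(torch.cat([theta, x], dim=1))
    logits_marginal = classifier(torch.cat([theta_marginal, x], dim=1))
    loss = F.binary_cross_entropy_with_logits(
        logits_joint, torch.ones_like(logits_joint)
    ) + F.binary_cross_entropy_with_logits(
        logits_marginal, torch.zeros_like(logits_marginal)
    )
    loss.backward()
    optimizer.step()


def log_ratio(theta, x):
    # Hypothetical helper: the classifier logit approximates the log density ratio.
    return classifier(torch.cat([theta, x], dim=1)).squeeze(-1)
```

Everything a training loop needs here already comes from nn.Module, which is the argument for not adding an extra wrapper class.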

  • [ ] 6) Integrate new interface into SNRE base classes.

This would only be necessary if we do 3.

gmoss13, Mar 07 '24 15:03

Closing this in favour of #992, as everything else has already been addressed.

gmoss13, Mar 24 '24 13:03