
Interchangeability of `Callable` and `BasePotential`

Open schroedk opened this issue 1 year ago • 7 comments

I have a question regarding the interchangeability of the `potential_fn` argument of

```python
class NeuralPosterior(ABC):
    r"""Posterior $p(\theta|x)$ with `log_prob()` and `sample()` methods.<br/><br/>
    All inference methods in sbi train a neural network which is then used to obtain
    the posterior distribution. The `NeuralPosterior` class wraps the trained network
    such that one can directly evaluate the (unnormalized) log probability and draw
    samples from the posterior.
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],
        # ... (remaining parameters omitted)
```

For the callable case, it must be something like:

```python
def potential(theta=None, x_o=None):
    ...
```

in contrast to `BasePotential`, which is a `Callable` with `theta` as a positional argument and `track_gradients` as a keyword argument, correct? Is this tested somewhere? I only found examples where the argument is of type `BasePotential`.
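For concreteness, a custom callable potential could look like the following minimal sketch. It is dependency-free (plain floats instead of `torch.Tensor`s), and the toy Gaussian model and the names `gaussian_log_pdf` and `custom_potential` are hypothetical illustrations rather than sbi code.

```python
import math


def gaussian_log_pdf(value: float, mean: float, std: float) -> float:
    """Log-density of a univariate normal distribution N(mean, std**2)."""
    z = (value - mean) / std
    return -0.5 * z**2 - math.log(std * math.sqrt(2 * math.pi))


def custom_potential(theta: float, x_o: float) -> float:
    """Hypothetical user-defined potential (unnormalized log posterior).

    Assumed toy model: x_o ~ N(theta, 1) and prior theta ~ N(0, 1).
    """
    log_likelihood = gaussian_log_pdf(x_o, mean=theta, std=1.0)
    log_prior = gaussian_log_pdf(theta, mean=0.0, std=1.0)
    return log_likelihood + log_prior


# Both theta and the observation x_o are passed explicitly, by keyword.
value = custom_potential(theta=0.5, x_o=1.0)
```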

schroedk, Aug 19 '24 10:08

Related #1055

schroedk, Aug 19 '24 10:08

Yes, correct.

The reason that the custom `potential_fn` takes `theta` and `x_o` as arguments is that both quantities are required to compute the "potential", i.e., the unnormalized log posterior probability.
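As an illustration (standard Bayes' rule, not taken from the thread), for a likelihood-based potential the two inputs enter as log-likelihood plus log-prior, which equals the log posterior up to the additive constant $\log p(x_o)$, a term that does not depend on $\theta$:

$$
\text{potential}(\theta; x_o) = \log p(x_o \mid \theta) + \log p(\theta) = \log p(\theta \mid x_o) + \log p(x_o)
$$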

For a `BasePotential`, the `__call__` method does not take `x_o` as an argument, because `x_o` is set as a property at runtime.
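A minimal sketch of this convention, assuming plain floats in place of tensors: the hypothetical class `StoredXPotential` mimics the `BasePotential` calling convention (observation stored on the object, `__call__(theta, track_gradients=True)`), but it is not the sbi class and omits everything else that class does.

```python
from typing import Optional


class StoredXPotential:
    """Hypothetical, simplified stand-in for a BasePotential-style object.

    x_o is stored on the object and set at runtime, so __call__ only
    receives theta plus the track_gradients keyword.
    """

    def __init__(self, x_o: Optional[float] = None):
        self._x_o = x_o

    @property
    def x_o(self) -> float:
        if self._x_o is None:
            raise ValueError("No observation set; assign `x_o` before calling.")
        return self._x_o

    @x_o.setter
    def x_o(self, value: float) -> None:
        self._x_o = value

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        # track_gradients is accepted for signature compatibility only;
        # with plain floats there are no gradients to track.
        log_likelihood = -0.5 * (self.x_o - theta) ** 2
        log_prior = -0.5 * theta**2
        # Additive constants are dropped: a potential only needs to be
        # correct up to a constant.
        return log_likelihood + log_prior


potential = StoredXPotential()
potential.x_o = 1.0     # observation set at runtime
value = potential(0.5)  # only theta is passed
```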

If a user passes a custom potential, its signature is checked for the required arguments here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/sbi/inference/posteriors/base_posterior.py#L57-L69

and it is then wrapped as a `BasePotential` here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/sbi/inference/potentials/base_potential.py#L80-L97
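The check-and-wrap mechanism at the linked lines can be approximated by the following simplified, dependency-free sketch. The names `missing_potential_args` and `CallableWrapper` are hypothetical, plain floats stand in for tensors, and raising a `TypeError` is an assumption; the actual sbi check may report the error differently.

```python
import inspect
from typing import Callable, List, Optional

REQUIRED_ARGS = ("theta", "x_o")


def missing_potential_args(potential_fn: Callable) -> List[str]:
    """Return the required argument names that `potential_fn` does not accept."""
    params = inspect.signature(potential_fn).parameters
    return [name for name in REQUIRED_ARGS if name not in params]


class CallableWrapper:
    """Hypothetical wrapper: adapts f(theta, x_o) to the
    (theta, track_gradients) convention, storing x_o on the object."""

    def __init__(self, potential_fn: Callable, x_o: Optional[float] = None):
        missing = missing_potential_args(potential_fn)
        if missing:
            raise TypeError(f"custom potential is missing arguments: {missing}")
        self.potential_fn = potential_fn
        self.x_o = x_o

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        # A torch-based version would toggle gradient tracking here; with
        # plain floats the flag is accepted but unused.
        return self.potential_fn(theta=theta, x_o=self.x_o)


def my_potential(theta: float, x_o: float) -> float:
    # Toy unnormalized log posterior (additive constants dropped).
    return -0.5 * (x_o - theta) ** 2 - 0.5 * theta**2


wrapped = CallableWrapper(my_potential, x_o=1.0)
value = wrapped(0.5)  # only theta is passed; x_o comes from the wrapper
```

Because the check inspects parameter names, a function written as `potential(theta, x0)` would be rejected, since it lacks `x_o`.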

janfb, Aug 19 '24 15:08

> Is this tested somewhere? I only found examples where the argument is of type BasePotential.

Yes, I had to dig a bit as well, but it's tested here:

https://github.com/sbi-dev/sbi/blob/593e1533738bdc9c747d50f502f7c1a47bf94248/tests/potential_test.py#L28-L37

Here, you can see how we define a custom potential that depends on the inputs `theta` and `x_o`.

janfb, Aug 19 '24 15:08

I think this can be closed, feel free to reopen if anything is still unclear!

michaeldeistler, Aug 29 '24 08:08

I think this issue is a good starting point for refactoring the `Callable` potential API.

janfb, Aug 29 '24 09:08

What would have to be done here? Just more docs?

michaeldeistler, Aug 29 '24 09:08

At the moment, if a user passes just a `Callable` as the potential, we test at runtime whether it has the required arguments, i.e., `theta` and `x_o`. This is brittle. It would be nicer to enforce this beforehand with types, e.g., by defining a `Protocol` that ensures the passed `Callable` has the correct signature.
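A sketch of what such a `Protocol` could look like, using only the standard library. The protocol name `CustomPotentialFn` and the helper `build_posterior` are hypothetical. Note the limitation shown at the end: the signature is enforced by a static type checker, not by a runtime `isinstance()` check.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class CustomPotentialFn(Protocol):
    """Hypothetical callback protocol for a user-supplied potential."""

    def __call__(self, theta: float, x_o: float) -> float: ...


def build_posterior(potential_fn: CustomPotentialFn) -> float:
    # A static type checker such as mypy reports an error if the argument's
    # call signature (including parameter names) does not match __call__.
    return potential_fn(theta=0.5, x_o=1.0)


def good_potential(theta: float, x_o: float) -> float:
    return -0.5 * (x_o - theta) ** 2 - 0.5 * theta**2


def bad_potential(theta: float, x0: float) -> float:  # wrong keyword name
    return 0.0


result = build_posterior(good_potential)

# Caveat: isinstance() with a runtime_checkable protocol only checks that a
# __call__ attribute exists, not its signature, so both functions pass.
good_ok = isinstance(good_potential, CustomPotentialFn)
bad_ok = isinstance(bad_potential, CustomPotentialFn)
```

So a `Protocol` would move the check from runtime into static analysis; a runtime signature check would still be needed for users who do not run a type checker.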

janfb, Aug 29 '24 09:08

More context for resolving this: it makes sense to define `Potential` as a Python `Protocol` instead of an ABC. This would give more flexibility in designing custom potentials, without the need to inherit directly from a base class.

See, e.g., here for more details on the difference between abstract base classes and protocols: https://jellis18.github.io/post/2022-01-11-abc-vs-protocol/
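To illustrate the difference (a sketch of the design idea, not sbi's API; all class names are hypothetical), the following compares an ABC, which requires inheritance, with a `Protocol`, which any class with a matching `x_o` attribute and `__call__` method satisfies structurally:

```python
from abc import ABC, abstractmethod
from typing import Optional, Protocol, runtime_checkable


class PotentialABC(ABC):
    """ABC-style design: implementations must inherit from this class."""

    @abstractmethod
    def __call__(self, theta: float, track_gradients: bool = True) -> float: ...


@runtime_checkable
class PotentialProtocol(Protocol):
    """Protocol-style design: any object with an x_o attribute and a
    matching __call__ method conforms, without inheriting."""

    x_o: Optional[float]

    def __call__(self, theta: float, track_gradients: bool = True) -> float: ...


class StandaloneGaussianPotential:
    """Hypothetical user class with no base class at all."""

    def __init__(self, x_o: Optional[float] = None):
        self.x_o = x_o

    def __call__(self, theta: float, track_gradients: bool = True) -> float:
        assert self.x_o is not None, "set x_o before evaluating"
        # Toy unnormalized log posterior (additive constants dropped).
        return -0.5 * (self.x_o - theta) ** 2 - 0.5 * theta**2


potential = StandaloneGaussianPotential(x_o=1.0)
conforms_to_protocol = isinstance(potential, PotentialProtocol)  # True
is_abc_instance = isinstance(potential, PotentialABC)  # False
```

With the protocol, `StandaloneGaussianPotential` is accepted without importing or subclassing anything from the library; with the ABC, it would have to inherit from `PotentialABC`.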

janfb, Feb 11 '25 15:02