sbi

improve interface for custom density estimators passed to inference classes

Open janfb opened this issue 9 months ago • 1 comment

In all inference methods (child classes of NeuralInference), we generally allow users either to pass a string for building the density estimator internally, e.g., NPE(density_estimator="maf"), or to pass a plain Callable, which we then call internally as density_estimator(theta, x) to build the network.
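The two options can be sketched as follows. This is a minimal, dependency-free illustration, not sbi's actual code: `build_density_estimator`, `ToyEstimator`, and the string registry are hypothetical names, and nested lists stand in for torch Tensors.

```python
from typing import Callable, Dict, Sequence, Union

# Stand-in for a 2-D torch.Tensor (a batch of vectors) so the sketch runs without torch.
Tensor = Sequence[Sequence[float]]


class ToyEstimator:
    """Hypothetical estimator that records the dimensions it was built for."""

    def __init__(self, kind: str, theta_dim: int, x_dim: int) -> None:
        self.kind = kind
        self.theta_dim = theta_dim
        self.x_dim = x_dim


BuildFn = Callable[[Tensor, Tensor], ToyEstimator]

# Hypothetical registry mapping string shortcuts to internal build functions.
_REGISTRY: Dict[str, BuildFn] = {
    "maf": lambda theta, x: ToyEstimator("maf", len(theta[0]), len(x[0])),
    "nsf": lambda theta, x: ToyEstimator("nsf", len(theta[0]), len(x[0])),
}


def build_density_estimator(
    density_estimator: Union[str, BuildFn], theta: Tensor, x: Tensor
) -> ToyEstimator:
    """Resolve a string shortcut, or treat the argument as a build function."""
    if isinstance(density_estimator, str):
        if density_estimator not in _REGISTRY:
            raise ValueError(f"Unknown density estimator: {density_estimator!r}")
        build_fn = _REGISTRY[density_estimator]
    else:
        # A user-supplied callable; nothing checks what it returns.
        build_fn = density_estimator
    # In both cases the network is built from a batch of simulations (theta, x).
    return build_fn(theta, x)
```

Calling `build_density_estimator("maf", theta, x)` takes the string path, while passing a (hypothetical) `my_build_fn` takes the callable path; if that callable returns the wrong kind of object, the error only surfaces later, downstream.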

Two problems here:

  1. we should make it more explicit to the user what to pass here, e.g., by defining a protocol DensityEstimatorBuilder that ensures the custom density estimator builder actually returns an object we can work with downstream. As a blueprint, see sbi/inference/trainers/npse/vector_field_inference.py in #1497
  2. more generally, the naming is a bit confusing, as pointed out by @StarostinV: the argument is not really a density_estimator but rather a density_estimator_build_fn. However, renaming it would imply a change to a central API.
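A minimal sketch of the protocol proposed in point 1, assuming a stand-in `ConditionalDensityEstimator` class in place of sbi's real base class (its actual interface is not reproduced here). `DensityEstimatorBuilder` follows the name suggested above; `build_checked` is a hypothetical helper whose argument uses the `density_estimator_build_fn` naming from point 2. The Protocol gives static type checkers the expected signature, and the `isinstance` check adds a runtime guarantee, since Protocol annotations are not enforced when the code runs.

```python
from typing import List, Protocol, Sequence

# Stand-in for a 2-D torch.Tensor so the sketch runs without torch.
Tensor = Sequence[Sequence[float]]


class ConditionalDensityEstimator:
    """Stand-in for the estimator base class that downstream code relies on."""

    def log_prob(self, theta: Tensor, condition: Tensor) -> List[float]:
        raise NotImplementedError


class DensityEstimatorBuilder(Protocol):
    """Any callable builder(theta, x) returning a ConditionalDensityEstimator."""

    def __call__(self, theta: Tensor, x: Tensor) -> ConditionalDensityEstimator:
        ...


def build_checked(
    density_estimator_build_fn: DensityEstimatorBuilder, theta: Tensor, x: Tensor
) -> ConditionalDensityEstimator:
    """Call the builder and fail early if it returns an unusable object."""
    estimator = density_estimator_build_fn(theta, x)
    # Protocol annotations are only checked statically, so verify at runtime too.
    if not isinstance(estimator, ConditionalDensityEstimator):
        raise TypeError(
            "density_estimator_build_fn must return a ConditionalDensityEstimator, "
            f"got {type(estimator).__name__}"
        )
    return estimator
```

Failing inside the trainer, right after the builder is called, gives the user an error message that names the argument they passed instead of an unrelated failure during training.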

janfb · Mar 25 '25 10:03

This could be one way to share the protocol among different estimator types (if I understood the protocol logic correctly):

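from typing import Protocol, TypeVar
from torch import Tensor
# Assumed import path (sbi's estimator base module); may differ between versions.
from sbi.neural_nets.estimators.base import (
    ConditionalDensityEstimator, ConditionalEstimator, ConditionalVectorFieldEstimator,
)

# Covariant because the type variable only appears in return position.
ConditionalEstimatorType = TypeVar(
    'ConditionalEstimatorType',
    bound=ConditionalEstimator,
    covariant=True,
)

# One builder protocol, parametrized by the estimator type it returns.
class ConditionalEstimatorBuilder(Protocol[ConditionalEstimatorType]):
    def __call__(self, theta: Tensor, x: Tensor) -> ConditionalEstimatorType:
        ...

# Specializations for the two estimator families.
vector_field_estimator_builder: ConditionalEstimatorBuilder[ConditionalVectorFieldEstimator]
density_estimator_builder: ConditionalEstimatorBuilder[ConditionalDensityEstimator]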
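To try the generic-protocol idea without installing sbi, the sketch below replaces the estimator classes with empty stand-ins and adds a hypothetical generic helper, `build_estimator`, plus two hypothetical builder functions. Because the type variable is covariant, a static checker such as mypy should accept a `ConditionalEstimatorBuilder[ConditionalDensityEstimator]` wherever a `ConditionalEstimatorBuilder[ConditionalEstimator]` is expected, while `build_estimator` keeps the specific estimator type in its return value.

```python
from typing import Protocol, Sequence, TypeVar

# Stand-in for a 2-D torch.Tensor so the sketch runs without torch.
Tensor = Sequence[Sequence[float]]


# Minimal stand-ins mirroring the estimator hierarchy named in the comment above.
class ConditionalEstimator:
    pass


class ConditionalDensityEstimator(ConditionalEstimator):
    pass


class ConditionalVectorFieldEstimator(ConditionalEstimator):
    pass


# Covariant: the type variable appears only as the protocol's return type.
ConditionalEstimatorType = TypeVar(
    "ConditionalEstimatorType", bound=ConditionalEstimator, covariant=True
)


class ConditionalEstimatorBuilder(Protocol[ConditionalEstimatorType]):
    def __call__(self, theta: Tensor, x: Tensor) -> ConditionalEstimatorType:
        ...


EstimatorT = TypeVar("EstimatorT", bound=ConditionalEstimator)


def build_estimator(
    builder: ConditionalEstimatorBuilder[EstimatorT], theta: Tensor, x: Tensor
) -> EstimatorT:
    """Hypothetical generic helper: the result type follows the builder's type."""
    return builder(theta, x)


# Plain functions match the protocol structurally; no subclassing needed.
def build_flow(theta: Tensor, x: Tensor) -> ConditionalDensityEstimator:
    return ConditionalDensityEstimator()


def build_vector_field(theta: Tensor, x: Tensor) -> ConditionalVectorFieldEstimator:
    return ConditionalVectorFieldEstimator()


density_estimator_builder: ConditionalEstimatorBuilder[ConditionalDensityEstimator] = (
    build_flow
)
vector_field_estimator_builder: ConditionalEstimatorBuilder[
    ConditionalVectorFieldEstimator
] = build_vector_field
```

mypy also reports an error when a protocol's type variable is invariant but only used in return position, so `covariant=True` is required for this pattern rather than a stylistic choice.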

StarostinV · Mar 25 '25 13:03

fixed with #1633

janfb · Aug 14 '25 11:08