sbi
improve interface for custom density estimators passed to inference classes
In all inference methods (child classes of NeuralInference), we generally allow the user either to pass a string for building the density estimator internally, e.g., NPE(density_estimator="maf"), or to pass a plain Callable, from which we then try to build the network internally by calling density_estimator(theta, x).
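A minimal sketch of the string-or-callable dispatch described above, assuming a registry that maps names such as "maf" to builder functions. The registry _BUILDERS, resolve_builder and the toy build_maf are hypothetical stand-ins, not sbi's actual internals, and plain lists replace torch tensors so the snippet runs without PyTorch.

```python
from typing import Any, Callable, Dict, Sequence, Union

# Hypothetical builder signature: (theta, x) -> estimator object.
Builder = Callable[[Sequence[float], Sequence[float]], Any]


def build_maf(theta: Sequence[float], x: Sequence[float]) -> Dict[str, Any]:
    # Toy stand-in for a real MAF builder: it only records the input sizes
    # it would use to set up the network.
    return {"kind": "maf", "theta_dim": len(theta), "x_dim": len(x)}


# Hypothetical registry of built-in estimators selectable by string.
_BUILDERS: Dict[str, Builder] = {"maf": build_maf}


def resolve_builder(density_estimator: Union[str, Builder]) -> Builder:
    """Turn the user's argument into a builder function."""
    if isinstance(density_estimator, str):
        try:
            return _BUILDERS[density_estimator]
        except KeyError:
            raise ValueError(f"Unknown density estimator: {density_estimator!r}") from None
    if callable(density_estimator):
        # A custom callable is trusted as-is; nothing checks what it returns.
        return density_estimator
    raise TypeError("density_estimator must be a string or a callable.")


theta, x = [0.1, 0.2], [1.0, 2.0, 3.0]
net_from_string = resolve_builder("maf")(theta, x)
net_from_callable = resolve_builder(lambda t, y: {"kind": "custom"})(theta, x)
```

The callable branch is where the problem lies: a custom builder is only called, and whatever it returns is passed downstream unchecked.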
Two problems here:
- We should make it more explicit to the user what to pass here, e.g., by defining a protocol DensityEstimatorBuilder that makes sure the custom density estimator builder actually returns an object we can work with downstream. As a blueprint, see sbi/inference/trainers/npse/vector_field_inference.py in #1497.
- More generally, the naming is a bit confusing, as pointed out by @StarostinV, because the argument is not really a density_estimator but rather a density_estimator_build_fn. However, renaming it would imply a breaking change to a central API.
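A minimal sketch of the first point, under stated assumptions: a Protocol documents the expected (theta, x) -> estimator signature for static type checkers, but it is not enforced at runtime, so the returned object is additionally checked with isinstance before use. ConditionalDensityEstimator here is an empty stand-in for sbi's class; build_checked, ToyEstimator and the two builders are hypothetical, and plain sequences replace torch tensors.

```python
from typing import Any, Protocol, Sequence


class ConditionalDensityEstimator:
    """Stand-in for sbi's density-estimator base class (assumption)."""

    def log_prob(self, theta: Sequence[float], x: Sequence[float]) -> float:
        raise NotImplementedError


class DensityEstimatorBuilder(Protocol):
    """Any callable taking (theta, x) and returning a ConditionalDensityEstimator."""

    def __call__(
        self, theta: Sequence[float], x: Sequence[float]
    ) -> ConditionalDensityEstimator:
        ...


def build_checked(
    builder: DensityEstimatorBuilder, theta: Sequence[float], x: Sequence[float]
) -> ConditionalDensityEstimator:
    # Protocols are only checked statically, so the return type is also
    # verified at runtime before the object is used downstream.
    estimator = builder(theta, x)
    if not isinstance(estimator, ConditionalDensityEstimator):
        raise TypeError(
            f"Custom builder returned {type(estimator).__name__}; "
            "expected a ConditionalDensityEstimator."
        )
    return estimator


class ToyEstimator(ConditionalDensityEstimator):
    def log_prob(self, theta: Sequence[float], x: Sequence[float]) -> float:
        return 0.0


def good_builder(theta: Sequence[float], x: Sequence[float]) -> ToyEstimator:
    return ToyEstimator()


def bad_builder(theta: Sequence[float], x: Sequence[float]) -> Any:
    return "not an estimator"


estimator = build_checked(good_builder, [0.0], [1.0])
try:
    build_checked(bad_builder, [0.0], [1.0])
    bad_builder_rejected = False
except TypeError:
    bad_builder_rejected = True
```

With this split, a type checker catches obviously wrong builders early, and the runtime check produces a clear error for builders typed as Any or not annotated at all.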
This might be one way to share the protocol among different estimator types (if I understood the protocol logic correctly):
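from typing import Protocol, TypeVar

from torch import Tensor

# ConditionalEstimator, ConditionalDensityEstimator and ConditionalVectorFieldEstimator
# are sbi's estimator base classes (import omitted here).

ConditionalEstimatorType = TypeVar(
    "ConditionalEstimatorType",
    bound=ConditionalEstimator,
    covariant=True,  # the type variable only appears as a return type
)

class ConditionalEstimatorBuilder(Protocol[ConditionalEstimatorType]):
    def __call__(self, theta: Tensor, x: Tensor) -> ConditionalEstimatorType:
        ...

# One protocol, specialized per estimator family:
vector_field_estimator_builder: ConditionalEstimatorBuilder[ConditionalVectorFieldEstimator]
density_estimator_builder: ConditionalEstimatorBuilder[ConditionalDensityEstimator]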
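A sketch of why covariant=True fits here: because the type variable only appears in the return position of __call__, a ConditionalEstimatorBuilder[ConditionalDensityEstimator] can be used wherever a ConditionalEstimatorBuilder[ConditionalEstimator] is expected, so shared training code can be typed once against the base class. The three estimator classes are empty stand-ins for sbi's classes, train_any, build_toy_flow and build_toy_field are hypothetical, and Sequence[float] replaces Tensor so the snippet runs without PyTorch.

```python
from typing import Protocol, Sequence, TypeVar


class ConditionalEstimator:
    """Stand-in for the shared estimator base class (assumption)."""


class ConditionalDensityEstimator(ConditionalEstimator):
    """Stand-in for density estimators (assumption)."""


class ConditionalVectorFieldEstimator(ConditionalEstimator):
    """Stand-in for vector-field estimators (assumption)."""


ConditionalEstimatorType = TypeVar(
    "ConditionalEstimatorType", bound=ConditionalEstimator, covariant=True
)


class ConditionalEstimatorBuilder(Protocol[ConditionalEstimatorType]):
    def __call__(
        self, theta: Sequence[float], x: Sequence[float]
    ) -> ConditionalEstimatorType:
        ...


def build_toy_flow(
    theta: Sequence[float], x: Sequence[float]
) -> ConditionalDensityEstimator:
    return ConditionalDensityEstimator()


def build_toy_field(
    theta: Sequence[float], x: Sequence[float]
) -> ConditionalVectorFieldEstimator:
    return ConditionalVectorFieldEstimator()


def train_any(
    builder: ConditionalEstimatorBuilder[ConditionalEstimator],
) -> ConditionalEstimator:
    # Shared code relies only on the base class; thanks to covariance, a type
    # checker accepts builders of any ConditionalEstimator subclass here.
    return builder([0.0], [1.0])


density_estimator_builder: ConditionalEstimatorBuilder[ConditionalDensityEstimator] = (
    build_toy_flow
)
vector_field_estimator_builder: ConditionalEstimatorBuilder[
    ConditionalVectorFieldEstimator
] = build_toy_field

flow = train_any(density_estimator_builder)
field = train_any(vector_field_estimator_builder)
```

Without covariant=True (i.e., an invariant type variable), a type checker would reject passing density_estimator_builder to train_any, even though the builder only ever produces subclasses of the expected base class.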
Fixed with #1633.