sbi

Add kwargs to SNPE_A.build_posterior to match base class API

Open · janfb opened this issue on Mar 13 '24 · 1 comment

The SNPE_A.build_posterior(...) method overrides that of the base class to correct for the proposal posterior.
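Some background on that correction: when later rounds simulate parameters from a proposal p̃(θ) instead of the prior p(θ), a density estimator trained on those simulations targets the proposal posterior p̃(θ | x), not the true posterior. The general reweighting identity relating the two is shown below. It is stated as background only, not as a description of how sbi implements the SNPE-A correction internally:

```latex
\tilde{p}(\theta \mid x) \;\propto\; p(\theta \mid x)\,\frac{\tilde{p}(\theta)}{p(\theta)}
\qquad\Longrightarrow\qquad
p(\theta \mid x) \;\propto\; \tilde{p}(\theta \mid x)\,\frac{p(\theta)}{\tilde{p}(\theta)}
```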

Problem: As a side effect, it has a different signature from the base class method. For example, it is not possible to pass arguments such as sample_with, mcmc_method, or mcmc_parameters. Although such kwargs are unlikely to be relevant when using SNPE_A, it would still be good to keep the API consistent with the base class.
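To make the mismatch concrete, here is a minimal, self-contained sketch. BaseInference and NarrowOverride are hypothetical stand-ins, not sbi classes, and their parameters only mimic the names mentioned above. It shows that an override with a narrower signature rejects keyword arguments the base class accepts:

```python
from typing import Any, Dict, Optional


class BaseInference:
    """Hypothetical base class; parameter names only mimic those in the issue."""

    def build_posterior(
        self,
        estimator: Optional[str] = None,
        prior: Optional[str] = None,
        sample_with: str = "direct",
        mcmc_method: str = "slice_np_vectorized",
        mcmc_parameters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # Return the configuration instead of building a real posterior.
        return {
            "estimator": estimator,
            "prior": prior,
            "sample_with": sample_with,
            "mcmc_method": mcmc_method,
            "mcmc_parameters": mcmc_parameters or {},
        }


class NarrowOverride(BaseInference):
    """Hypothetical subclass whose override drops the sampling options."""

    def build_posterior(  # type: ignore[override]
        self, estimator: Optional[str] = None, prior: Optional[str] = None
    ) -> Dict[str, Any]:
        # Always "direct": the sampling options cannot even be passed in.
        return {"estimator": estimator, "prior": prior, "sample_with": "direct"}


def accepts_sample_with(obj: BaseInference) -> bool:
    """Return True if obj.build_posterior accepts the sample_with keyword."""
    try:
        obj.build_posterior(sample_with="mcmc")
        return True
    except TypeError:
        return False


print(accepts_sample_with(BaseInference()))   # True
print(accepts_sample_with(NarrowOverride()))  # False
```

Code written against the base class, such as a helper that always passes sample_with, therefore fails with a TypeError when handed the subclass.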

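Solution: Add **kwargs to SNPE_A.build_posterior(...) and then call

```python
# Delegate to the base class so options such as sample_with, mcmc_method,
# and mcmc_parameters are forwarded unchanged via **kwargs.
self._posterior = super().build_posterior(
    posterior_estimator=wrapped_density_estimator,  # type: ignore
    prior=prior,
    **kwargs,
)
```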

instead of building the posterior directly with DirectPosterior(...).
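A minimal sketch of the proposed pattern follows, using hypothetical stand-ins again. ForwardingOverride and the tuple "wrapping" are illustrative placeholders for SNPE_A's proposal correction, not sbi code. The override does its own preprocessing and then forwards everything else to super().build_posterior via **kwargs:

```python
from typing import Any, Dict, Optional


class BaseInference:
    """Hypothetical base class; parameter names only mimic those in the issue."""

    def build_posterior(
        self,
        estimator: Any = None,
        prior: Optional[str] = None,
        sample_with: str = "direct",
        mcmc_method: str = "slice_np_vectorized",
        mcmc_parameters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # Return the configuration instead of building a real posterior.
        return {
            "estimator": estimator,
            "prior": prior,
            "sample_with": sample_with,
            "mcmc_method": mcmc_method,
            "mcmc_parameters": mcmc_parameters or {},
        }


class ForwardingOverride(BaseInference):
    """Hypothetical subclass: corrects the estimator, then delegates."""

    def build_posterior(
        self, estimator: Any = None, prior: Optional[str] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        # Placeholder for the subclass-specific correction (here: just tag it).
        wrapped = ("corrected", estimator)
        # Forward all remaining options unchanged; the base class validates them.
        return super().build_posterior(estimator=wrapped, prior=prior, **kwargs)


posterior = ForwardingOverride().build_posterior(
    "net", "prior", sample_with="mcmc", mcmc_parameters={"num_chains": 4}
)
print(posterior["sample_with"])  # mcmc
print(posterior["estimator"])    # ('corrected', 'net')
```

In this sketch, unknown keywords still reach the base method, so a typo such as sampel_with raises a TypeError there. The option handling therefore stays in one place, and the subclass does not need its own branching for each sampling method.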

janfb, Mar 13 '24 14:03

Happy to take care of it.

zinaStef, Mar 14 '24 09:03

I would also like to work on this if it is still available. I think @zinaStef has already modified the arguments. Does the method also need to provide the base class functionality, e.g., adding conditions to use MCMC or VI? @janfb

Ziaeemehr, Mar 18 '24 15:03