Add new MCMC samplers
- PyMC
- a torch library for MCMC?
- compile with `torch.compile`?
- numpyro: https://github.com/pyro-ppl/numpyro
- another one: https://github.com/blackjax-devs/blackjax
  Maybe we can write a wrapper to JAX or use this workaround: https://blackjax-devs.github.io/blackjax/examples/howto_other_frameworks.html
Additional context:
We could include more options for MCMC sampling from external libraries. For example, for Pyro this required writing this wrapper (note: if we add other external MCMC samplers, we should rename it to PyroMCMC) and this interface. A similar interface can be created for other MCMC samplers; see the examples above. As suggested in #703, we should add tests that these wrappers sample correctly. A speed comparison would also be useful, as would an FAQ entry on which sampler to use.
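A test that a wrapper "samples correctly" could run it on a target whose moments are known and compare sample moments against them. The sketch below is only an illustration under assumptions: `random_walk_metropolis` is a toy stand-in for a wrapped backend, `samples_match_standard_normal` is a hypothetical test helper, and the `sampler(log_prob, x0, num_samples)` calling convention is invented here, not sbi's actual API.

```python
import math
import random
import statistics
from typing import Callable, List

LogProb1D = Callable[[float], float]


def random_walk_metropolis(log_prob: LogProb1D, x0: float, num_samples: int,
                           step_size: float = 2.4, seed: int = 0) -> List[float]:
    """Toy 1-D random-walk Metropolis sampler standing in for a wrapped backend."""
    rng = random.Random(seed)
    x, lp = x0, log_prob(x0)
    samples: List[float] = []
    for _ in range(num_samples):
        proposal = x + rng.gauss(0.0, step_size)
        lp_prop = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(x)); min() avoids overflow in exp.
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            x, lp = proposal, lp_prop
        samples.append(x)
    return samples


def samples_match_standard_normal(sampler, num_samples: int = 50_000,
                                  burn_in: int = 1_000, tol: float = 0.1) -> bool:
    """Hypothetical check: draws from an unnormalized N(0, 1) should have mean ~0, variance ~1."""
    def log_prob(x: float) -> float:
        return -0.5 * x * x

    # Start away from the mode and discard the burn-in phase.
    draws = sampler(log_prob, 3.0, num_samples + burn_in)[burn_in:]
    mean = statistics.fmean(draws)
    var = statistics.pvariance(draws)
    return abs(mean) < tol and abs(var - 1.0) < tol
```

A moment check like this catches gross errors (e.g. a sign flip in the log-density, or a sampler that ignores the target) but not subtle bias; a real test suite might add stricter statistics or compare against reference samples.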
Good first issue? Happy to help crunch this for the hackathon.
Probably an advanced good first issue, but it does not really require knowing sbi. The main task is to find an interface between our MCMC API and that of the MCMC backend.
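One way to picture that interface is an adapter: our side hands over an unnormalized log-density over a flat parameter vector, and the adapter translates it into whatever the backend expects. Pyro's HMC/NUTS kernels, for instance, accept a `potential_fn` that maps a dict of parameters to a potential energy (negative log-density). The sketch below is hypothetical throughout: `MCMCBackend`, `ExternalBackendAdapter`, and `external_energy_sampler` are illustrative names, and the "external" sampler is a toy Metropolis stand-in, not any real library's API.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import Callable, Dict, List

LogProbFn = Callable[[List[float]], float]


class MCMCBackend(ABC):
    """Hypothetical common interface: log-density over a flat vector in, samples out."""

    @abstractmethod
    def sample(self, log_prob: LogProbFn, init: List[float],
               num_samples: int) -> List[List[float]]:
        ...


def external_energy_sampler(potential_energy: Callable[[Dict[str, float]], float],
                            init: Dict[str, float], num_samples: int,
                            seed: int = 0) -> List[Dict[str, float]]:
    """Toy stand-in for a third-party sampler with a different convention:
    it takes a potential energy (negative log-density) over named parameters."""
    rng = random.Random(seed)
    x, u = dict(init), potential_energy(init)
    out: List[Dict[str, float]] = []
    for _ in range(num_samples):
        prop = {k: v + rng.gauss(0.0, 0.5) for k, v in x.items()}
        u_prop = potential_energy(prop)
        # Metropolis acceptance written in energies: min(1, exp(U - U')).
        if rng.random() < math.exp(min(0.0, u - u_prop)):
            x, u = prop, u_prop
        out.append(dict(x))
    return out


class ExternalBackendAdapter(MCMCBackend):
    """Translates between our convention and the external sampler's convention."""

    def sample(self, log_prob, init, num_samples):
        names = [f"theta_{i}" for i in range(len(init))]

        def potential_energy(params: Dict[str, float]) -> float:
            # dict -> flat vector, and log-density -> energy (sign flip).
            return -log_prob([params[n] for n in names])

        draws = external_energy_sampler(potential_energy, dict(zip(names, init)), num_samples)
        # dict samples -> flat vectors in the original parameter order.
        return [[d[n] for n in names] for d in draws]


backend = ExternalBackendAdapter()
samples = backend.sample(lambda th: -0.5 * sum(t * t for t in th), [0.0, 0.0], 1000)
print(len(samples), len(samples[0]))  # 1000 2
```

Keeping all convention translation (parameter layout, sign of the density) inside the adapter means the rest of the code only ever sees one interface, and each new backend only needs its own adapter plus the correctness tests.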
See also #986 and #987
Solved for PyMC via #1053.
NumPyro and BlackJAX are not planned at the moment because we are built on PyTorch.