
Improve integration with BlackJAX

Open • twiecki opened this issue 3 years ago • 4 comments

We currently interface with the BlackJAX NUTS sampler, but the library offers many other inference algorithms, including Pathfinder.

It should be fairly simple to provide an interface so that any of these algorithms can run on PyMC models with minimal overhead.

Or maybe that's already very simple, in which case a good tutorial or example would be nice.
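A minimal, dependency-free sketch of what "one interface for all algorithms" could look like, assuming every algorithm is reduced to an (init, step) pair built from a model log-density, loosely modeled on how BlackJAX exposes its algorithms. `SamplingAlgorithm`, `random_walk_metropolis` and `run_algorithm` here are hypothetical names, and a toy random-walk Metropolis kernel stands in for a real BlackJAX algorithm.

```python
import math
import random
from typing import Callable, List, NamedTuple, Tuple

State = Tuple[List[float], float]  # (position, log-density at that position)


class SamplingAlgorithm(NamedTuple):
    """Hypothetical (init, step) pair, loosely modeled on BlackJAX's algorithm objects."""
    init: Callable[[List[float]], State]
    step: Callable[[random.Random, State], State]


def random_walk_metropolis(logdensity_fn: Callable[[List[float]], float], scale: float = 1.0) -> SamplingAlgorithm:
    """Toy random-walk Metropolis kernel standing in for a real BlackJAX algorithm."""

    def init(position: List[float]) -> State:
        return list(position), logdensity_fn(position)

    def step(rng: random.Random, state: State) -> State:
        position, logdensity = state
        proposal = [x + rng.gauss(0.0, scale) for x in position]
        proposal_logdensity = logdensity_fn(proposal)
        log_accept = proposal_logdensity - logdensity
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            return proposal, proposal_logdensity
        return state  # rejected: stay at the current position

    return SamplingAlgorithm(init, step)


def run_algorithm(algorithm: SamplingAlgorithm, initial_position: List[float],
                  num_samples: int, seed: int = 0) -> List[List[float]]:
    """Generic driver: any algorithm exposing init/step is run the same way."""
    rng = random.Random(seed)
    state = algorithm.init(initial_position)
    draws = []
    for _ in range(num_samples):
        state = algorithm.step(rng, state)
        draws.append(list(state[0]))
    return draws


def standard_normal_logdensity(position: List[float]) -> float:
    # Unnormalized log-density of independent standard normals,
    # playing the role of a compiled PyMC model logp
    return -0.5 * sum(x * x for x in position)


draws = run_algorithm(random_walk_metropolis(standard_normal_logdensity), [0.0, 0.0], num_samples=5000)
```

In an actual integration the log-density would come from the PyMC model compiled to JAX, and the toy kernel would be replaced by a BlackJAX algorithm (whose `step` also returns an info object alongside the new state); the driver loop would not need to change per algorithm.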

CC @rlouf

twiecki • Jun 26 '22 09:06

I would suggest doing this in pymc-experimental until we figure out what API we need to bridge the libraries.

It would also be a good hook to get people used to pymc-experimental.

ricardoV94 • Jun 26 '22 10:06

@ciguaran and I have been discussing and working on integration with BlackJAX's SMC. I focused on the talking and he focused on the working :-) The idea is to eventually replace the current SMC code in PyMC with only the bits needed to interface with BlackJAX and return an InferenceData object.

Regarding MCMC methods, I think we could work on two fronts/APIs: 1) an alternative to pm.sample, which should include a within-Gibbs sampling approach (assuming we want to preserve this feature) and initialization; 2) a series of pm.sample_* functions that call the different inference methods in BlackJAX.

I guess 1 is a must and 2 is nice to have, but I also see why PyMC might choose to stick with only 1.
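A minimal, dependency-free sketch of the within-Gibbs idea from option 1, assuming each block of variables gets its own update and the block updates are cycled once per iteration, similar to how pm.sample can assign different step methods to different variables. `metropolis_block_update` and `compound_sample` are hypothetical names, and plain random-walk Metropolis stands in for whatever BlackJAX kernels would actually be used per block.

```python
import math
import random
from typing import Callable, List, Sequence


def metropolis_block_update(logdensity_fn: Callable[[List[float]], float],
                            indices: Sequence[int], scale: float):
    """Build a step that perturbs only `indices`, holding all other coordinates fixed.

    Accepting with the ratio of the joint density is equivalent to targeting the
    conditional of the block given the rest, which is what a Gibbs-style update needs.
    """

    def step(rng: random.Random, position: List[float]) -> List[float]:
        proposal = list(position)
        for i in indices:
            proposal[i] += rng.gauss(0.0, scale)
        log_accept = logdensity_fn(proposal) - logdensity_fn(position)
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            return proposal
        return position  # rejected: keep the current values of this block

    return step


def compound_sample(logdensity_fn, block_steps, initial_position, num_samples, seed=0):
    """Metropolis-within-Gibbs: apply every block update once per iteration, in order."""
    rng = random.Random(seed)
    position = list(initial_position)
    draws = []
    for _ in range(num_samples):
        for step in block_steps:
            position = step(rng, position)
        draws.append(list(position))
    return draws


def logdensity(position: List[float]) -> float:
    # x0 ~ Normal(0, 1) and x1 ~ Normal(3, 0.5), independent (unnormalized)
    x0, x1 = position
    return -0.5 * x0 ** 2 - 0.5 * ((x1 - 3.0) / 0.5) ** 2


blocks = [
    metropolis_block_update(logdensity, [0], scale=1.0),
    metropolis_block_update(logdensity, [1], scale=0.3),
]
draws = compound_sample(logdensity, blocks, [0.0, 3.0], num_samples=3000)
```

Each block can use a different kernel or step size, which is the feature option 1 would need to preserve if pm.sample's default path moved to BlackJAX.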

aloctavodia • Jun 26 '22 22:06

I am not sure about replacing the current implementation altogether; keep in mind that jitted JAX can only support a subset of the models that Aesara can.

ricardoV94 • Jun 27 '22 04:06

There is a third way: you can wrap BlackJAX's algorithms in whatever data structure you need and have users pass that data structure to pm.sample if they want something other than NUTS. MCX was designed with this situation in mind.
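A minimal sketch of this third option, assuming the user-facing data structure bundles an algorithm builder with its parameters, while the sampling entry point supplies the model's log-density. `Algorithm`, `rwm_builder` and `sample` are hypothetical names, not PyMC, BlackJAX or MCX APIs, and a toy random-walk Metropolis kernel stands in for both the default (NUTS) and the user's alternative.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional


def rwm_builder(logdensity_fn: Callable, scale: float = 1.0):
    """Toy random-walk Metropolis builder: maps a log-density to (init, step)."""

    def init(position):
        return list(position), logdensity_fn(position)

    def step(rng, state):
        position, logdensity = state
        proposal = [x + rng.gauss(0.0, scale) for x in position]
        proposal_logdensity = logdensity_fn(proposal)
        log_accept = proposal_logdensity - logdensity
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            return proposal, proposal_logdensity
        return state

    return init, step


@dataclass
class Algorithm:
    """Hypothetical wrapper a user would pass to a pm.sample-like entry point."""
    name: str
    builder: Callable
    params: Dict[str, float] = field(default_factory=dict)


DEFAULT_ALGORITHM = Algorithm("default", rwm_builder)  # stands in for NUTS


def sample(logdensity_fn, initial_position, draws, algorithm: Optional[Algorithm] = None, seed=0):
    """The entry point owns the model's log-density; the user only picks the algorithm."""
    algorithm = algorithm or DEFAULT_ALGORITHM
    init, step = algorithm.builder(logdensity_fn, **algorithm.params)
    rng = random.Random(seed)
    state = init(initial_position)
    samples = []
    for _ in range(draws):
        state = step(rng, state)
        samples.append(list(state[0]))
    return algorithm.name, samples


def logdensity(position):
    # Unnormalized log-density of Normal(1, 1) in one dimension
    return -0.5 * (position[0] - 1.0) ** 2


name, samples = sample(logdensity, [1.0], 2000, Algorithm("rwm-wide", rwm_builder, {"scale": 2.0}))
```

The design choice here is that users never handle the log-density themselves: the entry point compiles it from the model and hands it to whichever builder they pass, so supporting another BlackJAX algorithm means wrapping it rather than adding a new pm.sample_* function.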

rlouf • Jun 27 '22 05:06