pymc
Improve integration with BlackJAX
We currently interface with the BlackJAX NUTS sampler, but the library has many other inference algorithms, including Pathfinder.
It should be fairly simple to provide an interface so that all of these algorithms can be run on PyMC models with minimal overhead.
Or maybe that's already very simple, in which case a good tutorial/example would be nice.
CC @rlouf
I would suggest doing this in pymc-experimental until we figure out what API we need to bridge the libraries.
It would also be a good hook to get people used to it.
With @ciguaran we have been talking about and working on integrating BlackJAX's SMC. I focused on the talking and he focused on the working :-) The idea is to eventually replace the current SMC code in PyMC with just the bits needed to interface with BlackJAX and return an InferenceData.
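The "return an InferenceData" part is mostly bookkeeping: raw draws come back as arrays and have to be split into named variables with leading (chain, draw) dimensions, which is the layout `arviz.from_dict(posterior=...)` accepts. Below is a minimal NumPy sketch, assuming the sampler output has already been stacked into a flat `(chains, draws, total_size)` array; `to_posterior_dict` is a hypothetical helper for illustration, not PyMC's actual code.

```python
import numpy as np


def to_posterior_dict(flat_samples, var_shapes):
    """Split raw sampler output into per-variable arrays (hypothetical helper).

    flat_samples: array of shape (chains, draws, total_size), where each row is
        the concatenation of all flattened model variables.
    var_shapes: ordered mapping of variable name -> shape of a single draw.
    Returns a dict of arrays shaped (chains, draws, *shape).
    """
    chains, draws, total_size = flat_samples.shape
    expected = sum(int(np.prod(shape)) for shape in var_shapes.values())
    if expected != total_size:
        raise ValueError(f"expected {expected} values per draw, got {total_size}")
    posterior = {}
    offset = 0
    for name, shape in var_shapes.items():
        size = int(np.prod(shape))  # np.prod(()) is 1.0, so scalars take one slot
        chunk = flat_samples[:, :, offset:offset + size]
        posterior[name] = chunk.reshape(chains, draws, *shape)
        offset += size
    return posterior


# Illustrative: 4 chains x 500 draws of a scalar "mu" and a length-3 vector "beta".
raw = np.random.default_rng(0).standard_normal((4, 500, 4))
posterior = to_posterior_dict(raw, {"mu": (), "beta": (3,)})
print({name: arr.shape for name, arr in posterior.items()})
# {'mu': (4, 500), 'beta': (4, 500, 3)}
```

The resulting dict could then be handed to ArviZ to build the InferenceData object.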
Regarding MCMC methods, I think we could work on two fronts/APIs: 1) an alternative to pm.sample, which should include a within-Gibbs sampling approach (assuming we want to preserve this feature) and initialization; 2) a series of pm.sample_* functions that call the different inference methods in BlackJAX.
I guess 1 is a must and 2 is nice to have. But I also see why PyMC could choose to stick with only 1.
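The within-Gibbs feature from front 1 amounts to cycling through blocks of variables and updating each block with its own kernel while holding the others fixed. Here is a minimal NumPy sketch of Metropolis-within-Gibbs on a correlated 2-D Gaussian; the target, the per-block step sizes, and the function names are illustrative assumptions, not PyMC's `CompoundStep` or any BlackJAX API.

```python
import numpy as np

# Illustrative target: bivariate normal with unit variances and correlation 0.5.
COV_INV = np.linalg.inv(np.array([[1.0, 0.5], [0.5, 1.0]]))


def logdensity_fn(x):
    return -0.5 * x @ COV_INV @ x


def metropolis_block_update(rng, x, block, scale):
    """Propose a move for the coordinates in `block` only, keeping the rest fixed."""
    proposal = x.copy()
    proposal[block] += scale * rng.standard_normal(len(block))
    # Symmetric proposal, so the acceptance ratio is just the density ratio.
    if np.log(rng.uniform()) < logdensity_fn(proposal) - logdensity_fn(x):
        return proposal
    return x


def within_gibbs(initial_position, blocks, scales, num_steps, seed=0):
    rng = np.random.default_rng(seed)
    x = np.asarray(initial_position, dtype=float)
    samples = np.empty((num_steps, x.size))
    for i in range(num_steps):
        # One sweep: update each block in turn with its own kernel.
        for block, scale in zip(blocks, scales):
            x = metropolis_block_update(rng, x, block, scale)
        samples[i] = x
    return samples


samples = within_gibbs(np.zeros(2), blocks=[[0], [1]], scales=[1.0, 1.0], num_steps=5000)
```

Each block's kernel could in principle be any step function restricted to that block; Metropolis is used here only because it fits in a few lines.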
I am not sure about replacing it altogether; keep in mind that JAX (jitted) can only support a subset of the models that Aesara can.
There is a third way: you can wrap BlackJAX's algorithms in whatever data structure you need and have users pass this data structure to pm.sample if they want something other than NUTS. MCX was designed with this situation in mind.
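A minimal sketch of this third option, assuming a hypothetical `AlgorithmSpec` container whose `build` field turns a log-density into an `(init, step)` pair (mirroring the init/step pattern BlackJAX algorithms follow), and a hypothetical `sample` entry point that accepts it in place of its default. The sampler, not the user, supplies the model's log-density. Plain NumPy is used so it runs standalone; `rmh` here is a stand-in random-walk Metropolis kernel, not BlackJAX's implementation.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional

import numpy as np


@dataclass(frozen=True)
class AlgorithmSpec:
    """Hypothetical container a user could pass to sample()."""

    build: Callable  # (logdensity_fn, **params) -> (init, step)
    params: dict = field(default_factory=dict)


def rmh(logdensity_fn, scale=1.0):
    """Stand-in random-walk Metropolis kernel in the (init, step) style."""

    def init(position):
        return position, logdensity_fn(position)

    def step(rng, state):
        position, logdensity = state
        proposal = position + scale * rng.standard_normal(position.shape)
        proposal_logdensity = logdensity_fn(proposal)
        if np.log(rng.uniform()) < proposal_logdensity - logdensity:
            return proposal, proposal_logdensity
        return state

    return init, step


def sample(logdensity_fn, initial_position, draws=1000,
           algorithm: Optional[AlgorithmSpec] = None, seed=0):
    """Hypothetical sample(): falls back to a default when no spec is given."""
    if algorithm is None:
        algorithm = AlgorithmSpec(rmh)  # stand-in for the default (NUTS in PyMC)
    # The model's log-density is plugged in here, not by the user.
    init, step = algorithm.build(logdensity_fn, **algorithm.params)
    rng = np.random.default_rng(seed)
    state = init(np.asarray(initial_position, dtype=float))
    out = np.empty((draws, state[0].size))
    for i in range(draws):
        state = step(rng, state)
        out[i] = state[0]
    return out


def logdensity_fn(x):
    # Stand-in for a model's log-density: a standard normal.
    return -0.5 * np.sum(x**2)


samples = sample(logdensity_fn, np.zeros(3), draws=2000,
                 algorithm=AlgorithmSpec(rmh, {"scale": 0.5}))
```

The design choice is that the container only describes how to build an algorithm, so the sampling entry point keeps control of the model, initialization, and output format regardless of which algorithm the user picks.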