Current MEADS implementation is incomplete
Current behavior
The MEADS adaptation routine appears to be incomplete. Currently, cross-chain statistics are computed each iteration and used to update a single set of kernel parameters shared by all chains. This misses some aspects of Algorithm 3 from Hoffman & Sountsov (2022). Perhaps this was deliberate, in which case I would be interested to know why.
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
Desired behavior
Hoffman & Sountsov (2022) describe an algorithm where the chains are split into $K$ folds. Cross-chain statistics are computed within each fold and used to update the neighbouring fold each iteration (skipping the fold equal to the current iteration modulo $K$), and all chains are shuffled across folds every $K$ steps; see the sketch below. It appears the original author implemented the algorithm in this notebook. I have recently experimented with modifying the BlackJAX MEADS implementation to reflect this, for a project testing new MCMC adaptation algorithms.
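To make the fold logic concrete, here is a minimal sketch of one iteration (not BlackJAX code: compute_params and kernel_step are hypothetical stand-ins for the MEADS statistics and the GHMC transition, and the shuffling schedule should be checked against Algorithm 3):

import jax
import jax.numpy as jnp

def eca_step(rng_key, states, params, iteration, num_folds,
             compute_params, kernel_step):
    # One iteration of the K-fold scheme, for `states` of shape
    # (num_folds, chains_per_fold, dim) and a list of per-fold parameters.
    step_key, shuffle_key = jax.random.split(rng_key)
    frozen = iteration % num_folds  # this fold's parameters are left untouched
    new_params = []
    for k in range(num_folds):
        if k == frozen:
            new_params.append(params[k])
        else:
            # Fold k is tuned with statistics from the neighbouring fold
            # k - 1, so no chain's parameters depend on its own history.
            new_params.append(compute_params(states[(k - 1) % num_folds]))
    new_states = jnp.stack([
        kernel_step(jax.random.fold_in(step_key, k), states[k], new_params[k])
        for k in range(num_folds)
    ])
    # Every K iterations, reshuffle all chains across the folds.
    if (iteration + 1) % num_folds == 0:
        flat = jax.random.permutation(
            shuffle_key, new_states.reshape(-1, new_states.shape[-1]))
        new_states = flat.reshape(new_states.shape)
    return new_states, new_params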
I propose updating the existing implementation to include the $K$-folding and shuffling described in the paper. This would introduce a few more parameters to blackjax.meads_adaptation, which could take their default values from the paper.
meads = blackjax.meads_adaptation(logdensity_fn, num_chains, num_folds=4, shuffle=True, step_size_multiplier=0.5, damping_slowdown=1.0)
The step_size_multiplier and damping_slowdown are hyperparameters used when computing the MEADS cross-chain statistics; a rough sketch of where they enter follows.
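Loosely following Algorithm 3 and the reference notebook, the two knobs would enter the per-fold statistics roughly like this (a sketch only: max_eigenvalue_estimate is a simple stand-in for the estimator in the paper, and the exact clipping and damping schedule should be taken from the paper rather than from here):

import jax.numpy as jnp

def max_eigenvalue_estimate(x):
    # Cheap largest-eigenvalue estimate, tr(S @ S) / tr(S), of the empirical
    # second-moment matrix; a stand-in for the estimator used in the paper.
    s = x.T @ x / x.shape[0]
    return jnp.trace(s @ s) / jnp.trace(s)

def meads_parameters(positions, gradients, iteration,
                     step_size_multiplier=0.5, damping_slowdown=1.0):
    # Rough per-fold statistics; `positions` and `gradients` are
    # (num_chains, dim) arrays from a single fold.
    sd = positions.std(axis=0)  # per-dimension scale across the fold's chains
    # The step size shrinks with the largest curvature estimate.
    step_size = step_size_multiplier / jnp.sqrt(
        max_eigenvalue_estimate(gradients * sd))
    # The damping decays like damping_slowdown / t early on, with a floor
    # so it never vanishes as the chains converge.
    normalized = (positions - positions.mean(axis=0)) / sd
    damping = jnp.maximum(
        damping_slowdown / (iteration + 1),
        step_size / jnp.sqrt(max_eigenvalue_estimate(normalized)))
    return step_size, damping, sd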
@albcab thoughts?
Hi @alexlyttle,
you are right, the current adaptation routine is not exactly the one in the paper. It was a long time ago, but I think in the end I wanted it to be an adaptation algorithm. MEADS is an adaptive algorithm (real adaptive MCMC: you never stop adapting, and you are still guaranteed convergence to your distribution). Adaptation algorithms, by contrast, are meant to stop adapting and then fix a kernel that you use to generate your samples; since that kernel doesn't change, you don't need to worry about changing hyperparameters ruining convergence results.
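For concreteness, the adaptation pattern I mean looks roughly like this (a sketch with a toy target; the exact return structure and parameter names may differ between blackjax versions):

import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # toy standard normal target
num_chains = 128
initial_positions = jax.random.normal(jax.random.PRNGKey(1), (num_chains, 2))

# Adapt for a fixed number of steps...
warmup = blackjax.meads_adaptation(logdensity_fn, num_chains)
(last_states, parameters), info = warmup.run(
    jax.random.PRNGKey(0), initial_positions, num_steps=1000)
# ...then freeze the tuned GHMC kernel and sample with it.
kernel = blackjax.ghmc(logdensity_fn, **parameters)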
When I wrote the original code, adaptation algorithms would run for a specified number of steps and return a kernel (GHMC in this case) to fix and use (here is the original code and discussion). Real MEADS would have returned $K$ kernels, which was an API problem. I also opened a pull request here to make it more adaptive-ish, but it was not merged; I can't remember why.
Anyway, the API has changed: adaptation algorithms now have stepping functions, and adaptive algorithms can be implemented more easily. So, by all means, go ahead and share/pull your complete adaptive MEADS!
Thanks @albcab, super informative! My follow-up question would have been about how the paper presents MEADS as an adaptive algorithm rather than something where you'd need to freeze the kernel, but this makes sense; I can see your reasoning in the original PR.
I am happy to share my version and contribute if it's welcome.
Practically, should this be added as its own MCMC algorithm in BlackJAX? My understanding of the paper is that MEADS is to GHMC what NUTS is to HMC (in the sense that it is adaptive).
That is a good question, more of a code-design/API question. You could either extend the current MEADS implementation in adaptation/ or create a new one in mcmc/. As you said, MEADS is more like NUTS than like window adaptation, and there are both hmc.py and nuts.py in mcmc/. But NUTS is somewhat special in Bayesian computation because of how well it works and thus how widely it is used. That said, I'm not sure whether you can also adapt the step size as real adaptive MCMC alongside NUTS; if not, that would make MEADS more "complete" than NUTS as adaptive MCMC.
I'm not up to date with code design in the library, so maybe @junpenglao is better for deciding this? I'd lean towards making a new algorithm in mcmc/ like NUTS.
I'm also a little distant from the code here, but from that short distance it would be neat to have ghmc and an ECA framework, so that implementing MEADS was just... smooshing them together (where "smooshing" means "carefully implementing"). I say ECA would be a "framework" because you would need to figure out how to handle the cross-chain communication, which is different from how blackjax normally runs adaptation, but then it should be usable with other MCMC step methods as well.
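Very roughly, something like this (entirely hypothetical names; eca_step is the fold-update sketch from the issue above):

from typing import Callable, NamedTuple

class ECASpec(NamedTuple):
    # Hypothetical interface: any MCMC transition plus a cross-chain
    # statistics function could be combined into an adaptive sampler.
    kernel_step: Callable     # (rng_key, fold_states, parameters) -> fold_states
    compute_params: Callable  # states of one fold -> parameters for its neighbour
    num_folds: int

def make_eca_sampler(spec: ECASpec):
    # Wrap any kernel with the K-fold update logic; MEADS would then just
    # be this with the GHMC transition and MEADS statistics plugged in.
    def step(rng_key, states, params, iteration):
        return eca_step(rng_key, states, params, iteration,
                        spec.num_folds, spec.compute_params, spec.kernel_step)
    return step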
Tagging @siegelordex in case he's got any suggestions/corrections.
@ColCarroll Note that the incoming EMAUS implementation might have some related code (see e.g. run_eca), although I'm not sure how relevant it is.
+1 what @ColCarroll said about having ECA as its own thing so it can be used with other algorithms.