Add the Generalized HMC sampler + ECA adaptation (MEADS)
This is a working draft of MEADS (#220). The MEADS kernel combines HMC with persistent momentum and a Metropolis-Hastings step based on nonreversible slice sampling (GHMC), together with parallel ensemble chain adaptation. GHMC uses the same flip_momentum function as HMC; maybe that function could be included in the public API of blackjax.mcmc.hmc so it can be reused as a utility function?
I still need to add docs and a better example of its use. The algorithm is most useful when running on multicore CPUs or GPUs.
Codecov Report
Merging #229 (013ef29) into main (8386c6f) will decrease coverage by 0.06%. The diff coverage is 97.48%.
@@ Coverage Diff @@
## main #229 +/- ##
==========================================
- Coverage 98.86% 98.79% -0.07%
==========================================
Files 43 46 +3
Lines 1756 1905 +149
==========================================
+ Hits 1736 1882 +146
- Misses 20 23 +3
| Impacted Files | Coverage Δ | |
|---|---|---|
| blackjax/__init__.py | 100.00% <ø> (ø) | |
| blackjax/mcmc/__init__.py | 100.00% <ø> (ø) | |
| blackjax/mcmc/hmc.py | 96.61% <80.00%> (-0.06%) | :arrow_down: |
| blackjax/mcmc/ghmc.py | 95.45% <95.45%> (ø) | |
| blackjax/adaptation/__init__.py | 100.00% <100.00%> (ø) | |
| blackjax/adaptation/chain_adaptation.py | 100.00% <100.00%> (ø) | |
| blackjax/adaptation/meads.py | 100.00% <100.00%> (ø) | |
| blackjax/kernels.py | 99.56% <100.00%> (+0.06%) | :arrow_up: |
| blackjax/mcmc/proposal.py | 100.00% <100.00%> (ø) | |
| blackjax/mcmc/trajectory.py | 95.72% <100.00%> (ø) | |
Many great additions in one PR! However, I am not a big fan of the warmup being run inside the kernel; what about users who would like to run Generalized HMC with their own parameters? There's a place for this kind of design, but imo blackjax is not where it should happen (PyMC would be, cf https://github.com/pymc-devs/pymc/issues/5930).
MEADS is a warmup scheme for GHMC and should be exposed as just that, an adaptation scheme (just like ChEES is an adaptation scheme for HMC). GHMC should be exposed as a kernel. What do you think?
I'm asking for @junpenglao's opinion as well.
+1 to @rlouf's point. For example, in the case of ChEES we will add an HMC kernel that randomizes the number of leapfrog steps (Uniform(min, max), where min and max are results of the ChEES tuning).
MEADS is a warmup scheme for GHMC and should be exposed as just that, an adaptation scheme (just like ChEES is an adaptation scheme for HMC). GHMC should be exposed as a kernel. What do you think?
Agree. I've pushed a version with meads and ghmc as separate kernels: the former is an adaptation scheme with a run API, and the latter a sampling scheme with init and step APIs.
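For reference, here is a rough sketch of how the two pieces could fit together from a user's perspective. The exact signatures (the arguments to blackjax.meads and blackjax.ghmc, the run() return values, the parameters name) are illustrative assumptions based on this draft, not the final API:

import jax
import jax.numpy as jnp
import blackjax

# Illustrative sketch only: signatures and return values below are assumptions.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)

rng_key = jax.random.PRNGKey(0)
num_chains = 128
initial_positions = jax.random.normal(rng_key, (num_chains, 2))

# 1. MEADS adaptation (Algorithm 3 of the paper) run on an ensemble of chains,
#    returning tuned GHMC parameters.
meads = blackjax.meads(logdensity_fn, num_chains)                       # assumed signature
last_states, parameters = meads.run(rng_key, initial_positions, 1000)   # assumed return values

# 2. GHMC sampling kernel built from the tuned parameters.
ghmc = blackjax.ghmc(logdensity_fn, **parameters)                       # assumed signature
state = ghmc.init(initial_positions[0], rng_key)
state, info = ghmc.step(rng_key, state)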
I've also recycled code from hmc.py with small modifications to the original code to make it work. I've included Neal's nonreversible slice sampling style MH in proposal.py. I need a version of velocity_verlet that uses a different step_size for each dimension, so I've included it as a separate function; I could possibly generalize the original velocity_verlet function or include a new, more general version in integrators.py. I could also include a persistent momentum version of gaussian_euclidean in metrics.py or shift the original version in ghmc.py as I do now. wdyt?
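For concreteness, a minimal standalone sketch of what a velocity verlet step with a per-dimension step size could look like (toy code, not the PR's actual integrator; the function and argument names are made up here):

import jax
import jax.numpy as jnp

def velocity_verlet_step(position, momentum, step_size, logdensity_fn):
    """One leapfrog step where `step_size` is a vector with one entry per dimension.

    Per-dimension step sizes play a role similar to a diagonal mass matrix
    (up to a rescaling of the momentum), which is why generalizing the
    existing integrator is an option.
    """
    grad_fn = jax.grad(logdensity_fn)
    momentum = momentum + 0.5 * step_size * grad_fn(position)  # half kick
    position = position + step_size * momentum                 # drift
    momentum = momentum + 0.5 * step_size * grad_fn(position)  # half kick
    return position, momentum

# Toy usage on a standard Gaussian target.
logdensity_fn = lambda x: -0.5 * jnp.sum(x ** 2)
q = jnp.array([1.0, -2.0])
p = jnp.array([0.5, 0.3])
eps = jnp.array([0.1, 0.05])  # a different step size per dimension
q, p = velocity_verlet_step(q, p, eps, logdensity_fn)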
I've also recycled code from hmc.py with small modifications to the original code to make it work.
Yes, it looks a lot better, thank you!
I've included Neal's nonreversible slice sampling style MH in proposal.py.
That's a good call.
I need a version of velocity_verlet that uses a different step_size for each dimension.
Very strange to me that they chose to do that instead of using a mass matrix but we'll have to deal with it I guess. I suggest generalizing the existing integrators. What do you think @junpenglao?
I could also include a persistent momentum version of gaussian_euclidean in metrics.py or shift the original version in ghmc.py as I do now. wdyt?
Shifting is great. To make the code easier to read, for now you can include the momentum update in a function in ghmc.py:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc import metrics


def update_momentum(rng_key, position, momentum, alpha):
    """Partial (persistent) momentum refreshment for GHMC."""
    m, _ = ravel_pytree(momentum)
    # Fresh Gaussian momentum with variance alpha (inverse mass matrix 1/alpha).
    momentum_generator, *_ = metrics.gaussian_euclidean(
        1 / alpha * jnp.ones(jnp.shape(m))
    )
    # m' = sqrt(1 - alpha) * m + noise, where noise ~ N(0, alpha).
    momentum = jax.tree_map(
        lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
        + shifted_momentum,
        momentum,
        momentum_generator(rng_key, position),
    )
    return momentum
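A quick toy check of the helper, continuing from the sketch above (the pytree shapes and alpha value are arbitrary):

rng_key, momentum_key = jax.random.split(jax.random.PRNGKey(0))
position = {"x": jnp.zeros(3)}
momentum = {"x": jax.random.normal(momentum_key, (3,))}

new_momentum = update_momentum(rng_key, position, momentum, alpha=0.3)
# If the previous momentum is standard normal, so is the refreshed one:
# Var = (1 - alpha) * 1 + alpha = 1.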
The code is asking for a generalization but I don't see it right now. I need to read more.
Other comments not in order:
- kernel_factory is now kernel in adaptation algorithms since we simplified the API.
- Also, did I miss something, or did you not fully implement the warmup?
- We probably want to vmap in parallel ECA, although we should probably leave the possibility for the user to choose, for instance with a kwarg batch_fn=jax.vmap (see the sketch below).
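Something along these lines, to sketch the batch_fn idea (run_chains and one_step are placeholder names, not the PR's code):

import jax

def run_chains(rng_key, one_step, initial_states, num_chains, batch_fn=jax.vmap):
    """Advance an ensemble of chains one step, mapping over the chain axis.

    `batch_fn` lets the user pick jax.vmap (single device) or jax.pmap
    (one chain group per device) without changing the adaptation code.
    """
    keys = jax.random.split(rng_key, num_chains)
    states, infos = batch_fn(one_step)(keys, initial_states)
    return states, infos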
Looking forward to playing with the code!
Very strange to me that they chose to do that instead of using a mass matrix but we'll have to deal with it I guess.
I think it's because of the persistent momentum. Assuming the initial momentum m has a standard normal distribution, the new momentum m' = sqrt(1 - alpha) * m + sqrt(alpha) * e, with e standard normal, also has a standard normal distribution (see the calculation of the MH acceptance rate). Maybe we could scale the momentum after its persistent update, before numerical integration, but then we would need to adjust the MH acceptance rate calculation to the scaled density.
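A quick numerical check of that claim (just an illustration, not part of the PR):

import jax
import jax.numpy as jnp

alpha = 0.3
key_m, key_e = jax.random.split(jax.random.PRNGKey(0))
m = jax.random.normal(key_m, (100_000,))
e = jax.random.normal(key_e, (100_000,))

m_new = jnp.sqrt(1.0 - alpha) * m + jnp.sqrt(alpha) * e
# Var(m') = (1 - alpha) + alpha = 1, so m' stays standard normal.
print(m_new.mean(), m_new.std())  # ~0.0, ~1.0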
This could be a nice option since it would avoid generalizing the numerical integrators. Just have to make sure it works theoretically.
I think that once you're using tree_multimap that's not a huge deal. Let's see when the warmup is fully implemented!
Let's see when the warmup is fully implemented!
The warmup, i.e. Algorithm 3 of the paper, is fully implemented by the meads kernel. The warmup_states returned by run are the samples Algorithm 3 outputs. It also returns a single kernel with parameters estimated using all num_steps samples from batch_size chains and num_batch groups.
kernel_factory is now kernel in adaptation algorithms since we simplified the API.
I'm not sure what you mean. All the other adaptation kernels use a function called kernel_factory.
Oh yes sorry I did not expect to find the warmup code in kernels.py! I'll need to take another look.
Would you mind organizing the warmup in adaptation/meads.py so that the following:
import jax
import jax.numpy as jnp

from blackjax.types import PyTree


def max_eigen(matrix: PyTree):
    """Estimate the largest eigenvalue of the covariance of the rows of `matrix`."""
    X = jnp.vstack(
        [leaf.T for leaf in jax.tree_leaves(matrix)]
    ).T  # will only work if all variables are at most vectors (not 2+ dimensional tensors)
    n, _ = X.shape
    S = X @ X.T
    diag_S = jnp.diag(S)
    lamda = jnp.sum(diag_S) / n
    lamda_sq = (jnp.sum(S**2) - jnp.sum(diag_S**2)) / (n * (n - 1))
    return lamda_sq / lamda


def parameter_gn(batch_state, current_iter):
    """Compute the GHMC parameters (step_size, alpha, delta) from the chain ensemble."""
    batch_position = batch_state.position
    # Per-dimension mean and standard deviation across the ensemble of chains.
    mean_position = jax.tree_map(lambda p: p.mean(axis=0), batch_position)
    sd_position = jax.tree_map(lambda p: p.std(axis=0), batch_position)
    batch_norm = jax.tree_map(
        lambda p, mu, sd: (p - mu) / sd,
        batch_position,
        mean_position,
        sd_position,
    )
    # `logprob_grad_fn` comes from the enclosing scope (the gradient of the target log density).
    batch_grad = jax.pmap(logprob_grad_fn)(batch_position)
    batch_grad_scaled = jax.tree_map(
        lambda grad, sd: grad * sd, batch_grad, sd_position
    )
    # Step size from the largest eigenvalue of the scaled gradients.
    epsilon = jnp.minimum(0.5 / jnp.sqrt(max_eigen(batch_grad_scaled)), 1.0)
    # Damping from the largest eigenvalue of the normalized positions.
    gamma = jnp.maximum(
        1.0 / jnp.sqrt(max_eigen(batch_norm)),
        1.0 / ((current_iter + 1) * epsilon),
    )
    alpha = 1.0 - jnp.exp(-2.0 * epsilon * gamma)
    delta = alpha / 2
    step_size = jax.tree_map(lambda sd: epsilon * sd, sd_position)
    return step_size, alpha, delta
is included in this file instead of kernels.py (and merge with what's in eca.py, at this stage there is no reason to separate the two)? kernels.py is supposed to be mostly interface code and plumbing.
There is still one big change to make: move over all the meads adaptation logic to eca.py (and rename it meads.py) so only the plumbing is left in kernels.py.
I've created meads.py as discussed. I've also added a new file called chain_adaptation.py, which includes functions that perform parallel ensemble chain adaptation and cross-chain adaptation.

While implementing ChEES-HMC on blackjax's API for my research, I've noticed that MEADS' Algorithm 3 is implemented as an adaptive MCMC algorithm, in the sense that it is meant to generate correct samples as it adapts. For an algorithm like this to be ergodic it needs to satisfy certain criteria, which is why we need parallel ensemble chain adaptation. On the other hand, ChEES' Algorithm 1 is not an adaptive algorithm: it is meant to adapt during a warm-up phase (generating incorrect samples), after which the algorithm's parameters are fixed and correct samples can be generated. MEADS can also be run like this, not as adaptive MCMC, hence I've given the user the choice (a bool parameter eca) of running adaptive MCMC (where the samples from warmup_states can be used) or of running an adaptation that returns a ghmc kernel with fixed parameters (same as window_adaptation). We have the choice to implement these options as two separate kernels or as one kernel with a bool parameter. In any case, I'll need cross-chain adaptation to implement ChEES, so I've included it in a new file (I also think it'll be useful for other adaptive algorithms).
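Schematically, the two modes differ roughly as follows (a toy sketch with made-up helper names, not the PR's actual functions; step_fn advances the whole ensemble one step and update_fn recomputes the GHMC parameters from the ensemble):

import jax

def run_adaptation(rng_key, states, parameters, num_steps, step_fn, update_fn, keep_samples):
    """Toy sketch of the two MEADS modes."""
    samples = []
    for i in range(num_steps):
        rng_key, step_key = jax.random.split(rng_key)
        states = step_fn(step_key, states, parameters)
        parameters = update_fn(states, i)
        if keep_samples:  # eca=True: adaptive MCMC, the states generated while adapting are kept as samples
            samples.append(states)
    # eca=False: adaptation samples are discarded; `parameters` defines a fixed ghmc kernel
    return (samples, states, parameters) if keep_samples else (states, parameters)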
decide what we do about
pmapvsvmap
I've included a batch_fn option for the user and fixed other minor things.
Let me know what you think about the current implementation and if there is something that can be improved. I'm on vacation for the coming two weeks, but after that I'll add docstrings and tests, finish this PR, open a PR for ChEES, and add an example of MEADS and ChEES, probably reproducing the experiments of the MEADS paper or something similar.
I've left only the version of MEADS that works like window_adaptation for GHMC. It does change the current API for warm-ups and from various experiments I've found no improvement with the adaptive version.
Also, there are some examples using real data that I thought could be interesting to have in the documentation.
Almost there! Two things related to the notebooks:
1- We can probably skip MEADS.ipynb, it is not adding anything that is not in the two others;
2- The R-hat diagnostic takes values greater than 1 in the sparse regression model, which means the chains are not mixing well (one way to check this is sketched below).
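For instance, the diagnostic can be computed from the sampled positions with blackjax's diagnostics module, assuming the samples are stacked with shape (num_chains, num_samples):

import jax
from blackjax.diagnostics import potential_scale_reduction

# Stand-in samples with shape (num_chains, num_samples); values close to 1 indicate good mixing.
positions = jax.random.normal(jax.random.PRNGKey(0), (4, 1000))
rhat = potential_scale_reduction(positions)
print(rhat)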
Great! I've made the corrections you've recommended and skipped MEADS.ipynb. I've noticed all notebooks are .md now; should I push them as .md files directly? Also, I've improved the R-hat on the sparse regression model by increasing the number of iterations (given the high dimension of the target space and the hierarchies, this model tends to be a hard one to sample from). Comparisons with ChEES and other parallel MCMC methods should be interesting on these examples.
Great! Yes, would you mind converting the notebooks to .md following the instructions in the README? Good on my end otherwise.
@junpenglao do you want to take a look?
@junpenglao ping
Per discussion with @rlouf offline, we will merge first and make changes to the API with other refactoring efforts.
Great job! Thank you for another great contribution and your patience with the process.