
Add the Generalized HMC sampler + ECA adaptation (MEADS)

Open albcab opened this issue 3 years ago • 11 comments

This is a working draft of MEADS (#220). The MEADS kernel combines HMC with persistent momentum and a Metropolis-Hastings step based on nonreversible slice sampling (i.e. generalized HMC, GHMC), together with parallel ensemble chain adaptation. GHMC uses the same flip_momentum function as HMC; maybe that function could be included in the public API of blackjax.mcmc.hmc so it can be used as a utility function?

I still need to add docs and a better example of its use. The algorithm is most useful when running on multicore CPUs or GPUs.
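
As an aside, one way to exercise the parallel behaviour locally on a multicore CPU (illustrative only, not part of this PR) is to force XLA to expose several host devices before importing jax:

    # JAX exposes a multicore CPU as a single device by default, so pmap over
    # chains would otherwise run on one core. The device count here is arbitrary.
    import os

    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax

    print(jax.local_device_count())  # now reports 8 CPU devices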

albcab avatar Jun 24 '22 15:06 albcab

Codecov Report

Merging #229 (013ef29) into main (8386c6f) will decrease coverage by 0.06%. The diff coverage is 97.48%.

@@            Coverage Diff             @@
##             main     #229      +/-   ##
==========================================
- Coverage   98.86%   98.79%   -0.07%     
==========================================
  Files          43       46       +3     
  Lines        1756     1905     +149     
==========================================
+ Hits         1736     1882     +146     
- Misses         20       23       +3     
Impacted Files                             Coverage Δ
blackjax/__init__.py                       100.00% <ø> (ø)
blackjax/mcmc/__init__.py                  100.00% <ø> (ø)
blackjax/mcmc/hmc.py                       96.61% <80.00%> (-0.06%) ↓
blackjax/mcmc/ghmc.py                      95.45% <95.45%> (ø)
blackjax/adaptation/__init__.py            100.00% <100.00%> (ø)
blackjax/adaptation/chain_adaptation.py    100.00% <100.00%> (ø)
blackjax/adaptation/meads.py               100.00% <100.00%> (ø)
blackjax/kernels.py                        99.56% <100.00%> (+0.06%) ↑
blackjax/mcmc/proposal.py                  100.00% <100.00%> (ø)
blackjax/mcmc/trajectory.py                95.72% <100.00%> (ø)


codecov[bot] avatar Jun 24 '22 16:06 codecov[bot]

Many great additions in one PR! However, I am not a big fan of the warmup being run inside the kernel; what about users who would like to run Generalized HMC with their own parameters? There's a place for this kind of design, but imo blackjax is not where it should happen (PyMC would be, cf. https://github.com/pymc-devs/pymc/issues/5930).

MEADS is a warmup scheme for GHMC and should be exposed just as that, an adaptation scheme (just like ChEES is an adaptation scheme for HMC). GHMC should be exposed as a kernel. What do you think?

I'm asking for @junpenglao's opinion as well.

rlouf avatar Jun 27 '22 06:06 rlouf

+1 to @rlouf's point. For example, in the case of ChEES we will add an HMC kernel that randomizes the number of leapfrog steps (Uniform(min, max)), where min and max are results of the ChEES tuning.
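
A minimal illustrative sketch of that randomization (not code from the PR; the function name is a placeholder):

    import jax

    def draw_num_integration_steps(rng_key, min_steps, max_steps):
        # Number of leapfrog steps drawn uniformly in [min_steps, max_steps];
        # min_steps and max_steps would come out of the ChEES tuning.
        return jax.random.randint(rng_key, (), min_steps, max_steps + 1)

    # e.g. num_steps = draw_num_integration_steps(jax.random.PRNGKey(0), 5, 20)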

junpenglao avatar Jun 27 '22 06:06 junpenglao

MEADS is a warmup scheme for GHMC and should be exposed just as that, an adaptation scheme (just like ChEES is an adaptation scheme for HMC). GHMC should be exposed as a kernel. What do you think?

Agree. I've pushed a version with meads and ghmc as separate kernels: the former is an adaptation scheme with a run API and the latter a sampling scheme with init and step APIs.

I've also recycled code from hmc.py, with small modifications to make it work. I've included Neal's nonreversible slice sampling style MH in proposal.py. I need a version of velocity_verlet that uses a different step_size for each dimension, so I've included it as a separate function; I could possibly generalize the original velocity_verlet function or include a new general version in integrators.py. I could also include a persistent momentum version of gaussian_euclidean in metrics.py, or shift the original version into ghmc.py as I do now. wdyt?
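
To illustrate the intended split, a rough usage sketch (not the final API; meads and ghmc stand in for the two new kernels, and the target and shapes are placeholders):

    import jax
    import jax.numpy as jnp

    # Placeholder standard normal target and illustrative shapes.
    logprob_fn = lambda x: -0.5 * jnp.sum(x**2)
    num_chains = 128
    initial_positions = jnp.zeros((num_chains, 2))
    warmup_key, init_key, step_key = jax.random.split(jax.random.PRNGKey(0), 3)

    # Adaptation scheme: a single `run` call returning adapted GHMC parameters.
    last_states, parameters = meads(logprob_fn, num_chains).run(warmup_key, initial_positions)

    # Sampling scheme: the usual `init`/`step` pair with fixed parameters.
    sampler = ghmc(logprob_fn, **parameters)
    state = sampler.init(jnp.zeros(2), init_key)
    new_state, info = sampler.step(step_key, state)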

albcab avatar Jul 05 '22 17:07 albcab

I've also recycled code from hmc.py with small modifications to the original code to make it work.

Yes, it looks a lot better, thank you!

I've included Neal's nonreversible slice sampling style MH in proposal.py.

That's a good call.

I need a version of velocity_verlet that uses a different step_size for each dimension.

Very strange to me that they chose to do that instead of using a mass matrix, but we'll have to deal with it I guess. I suggest generalizing the existing integrators. What do you think, @junpenglao?
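
For concreteness, a self-contained sketch (not blackjax's integrator) of a velocity Verlet step where step_size is a pytree with the same structure as the position, assuming a unit mass matrix:

    import jax
    import jax.numpy as jnp

    def velocity_verlet_step(position, momentum, step_size, logprob_grad_fn):
        """One leapfrog step where `step_size` matches the structure of `position`.

        A scalar step size recovers the usual integrator; a unit mass matrix is
        assumed to keep the sketch short.
        """
        grad = logprob_grad_fn(position)
        momentum = jax.tree_map(
            lambda m, g, eps: m + 0.5 * eps * g, momentum, grad, step_size
        )
        position = jax.tree_map(
            lambda q, m, eps: q + eps * m, position, momentum, step_size
        )
        grad = logprob_grad_fn(position)
        momentum = jax.tree_map(
            lambda m, g, eps: m + 0.5 * eps * g, momentum, grad, step_size
        )
        return position, momentum

    # Example on a standard normal target with per-dimension step sizes.
    logprob_grad_fn = jax.grad(lambda q: -0.5 * jnp.sum(q**2))
    q, p = jnp.ones(3), jnp.ones(3)
    q, p = velocity_verlet_step(q, p, jnp.array([0.1, 0.2, 0.3]), logprob_grad_fn)

(A per-dimension step size with identity mass should be equivalent to a single step size with a diagonal mass matrix whose entries are 1/epsilon_i^2, up to a rescaling of the momentum.)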

I could also include a persistent momentum version of gaussian_euclidean in metrics.py or shift the original version in ghmc.py as I do now. wdyt?

Shifting is great. To make the code easier to read, for now you can include the momentum update in a function in ghmc.py:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc import metrics


def update_momentum(rng_key, position, momentum, alpha):
    # Persistent update: momentum' = sqrt(1 - alpha) * momentum + sqrt(alpha) * noise
    m, unravel_fn = ravel_pytree(momentum)
    momentum_generator, *_ = metrics.gaussian_euclidean(
        1 / alpha * jnp.ones(jnp.shape(m))
    )
    momentum = jax.tree_map(
        lambda prev_momentum, shifted_momentum: prev_momentum
        * jnp.sqrt(1.0 - alpha)
        + shifted_momentum,
        momentum,
        momentum_generator(rng_key, position),
    )

    return momentum

The code is asking for a generalization but I don't see it right now. I need to read more.

Other comments not in order:

  • kernel_factory is now kernel in adaptation algorithms since we simplified the API.
  • Also, did I miss something, or did you not fully implement the warmup?
  • We probably want to vmap in parallel ECA, although we should probably leave the possibility for the user to choose, for instance with a kwarg batch_fn=jax.vmap.

Looking forward to playing with the code!

rlouf avatar Jul 06 '22 16:07 rlouf

Very strange to me that they chose to do that instead of using a mass matrix but we'll have to deal with it I guess.

I think it's because of the persistent momentum. Assuming the initial momentum m has a standard normal distribution, the new momentum m' = sqrt(1 - alpha) * m + sqrt(alpha) * e, for e standard normal, also has a standard normal distribution (see the calculation of the MH acceptance rate). Maybe we could scale the momentum after its persistent update, before numerical integration, but we would need to adjust the MH acceptance rate calculation to the scaled density.

This could be a nice option since it would avoid generalizing the numerical integrators. We just have to make sure it works theoretically.
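
As a quick numerical sanity check of the variance argument above (illustrative only, not part of the PR):

    import jax
    import jax.numpy as jnp

    # If m ~ N(0, 1) and e ~ N(0, 1) are independent, then
    # m' = sqrt(1 - alpha) * m + sqrt(alpha) * e has variance
    # (1 - alpha) + alpha = 1, i.e. it is again standard normal.
    alpha = 0.3
    key_m, key_e = jax.random.split(jax.random.PRNGKey(0))
    m = jax.random.normal(key_m, (100_000,))
    e = jax.random.normal(key_e, (100_000,))
    m_new = jnp.sqrt(1.0 - alpha) * m + jnp.sqrt(alpha) * e
    print(m_new.mean(), m_new.var())  # close to 0 and 1 respectively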

albcab avatar Jul 06 '22 17:07 albcab

I think that once you're using tree_multimap that's not a huge deal. Let's see when the warmup is fully implemented!

rlouf avatar Jul 06 '22 18:07 rlouf

Let's see when the warmup is fully implemented!

The warm-up, i.e. Algorithm 3 of the paper, is fully implemented by the meads kernel. The warmup_states returned by run are the samples that Algorithm 3 outputs. It also returns a single kernel with parameters estimated using all num_steps samples from batch_size chains and num_batch groups.

  • kernel_factory is now kernel in adaptation algorithms since we simplified the API.

I'm not sure what you mean. All the other adaptation kernels use a function called kernel_factory.

albcab avatar Jul 06 '22 18:07 albcab

Oh yes, sorry, I did not expect to find the warmup code in kernels.py! I'll need to take another look.

rlouf avatar Jul 06 '22 19:07 rlouf

Would you mind organizing the warmup in adaptation/meads.py so that the following:

    def max_eigen(matrix: PyTree):
        X = jnp.vstack(
            [leaf.T for leaf in jax.tree_leaves(matrix)]
        ).T  # will only work if all variables are at most vectors (not 2+ dimensional tensors)
        n, _ = X.shape
        S = X @ X.T
        diag_S = jnp.diag(S)
        lamda = jnp.sum(diag_S) / n
        lamda_sq = (jnp.sum(S**2) - jnp.sum(diag_S**2)) / (n * (n - 1))
        return lamda_sq / lamda

    def parameter_gn(batch_state, current_iter):
        batch_position = batch_state.position
        mean_position = jax.tree_map(lambda p: p.mean(axis=0), batch_position)
        sd_position = jax.tree_map(lambda p: p.std(axis=0), batch_position)
        batch_norm = jax.tree_map(
            lambda p, mu, sd: (p - mu) / sd,
            batch_position,
            mean_position,
            sd_position,
        )
        batch_grad = jax.pmap(logprob_grad_fn)(batch_position)
        batch_grad_scaled = jax.tree_map(
            lambda grad, sd: grad * sd, batch_grad, sd_position
        )
        epsilon = jnp.minimum(0.5 / jnp.sqrt(max_eigen(batch_grad_scaled)), 1.0)
        gamma = jnp.maximum(
            1.0 / jnp.sqrt(max_eigen(batch_norm)),
            1.0 / ((current_iter + 1) * epsilon),
        )
        alpha = 1.0 - jnp.exp(-2.0 * epsilon * gamma)
        delta = alpha / 2
        step_size = jax.tree_map(lambda sd: epsilon * sd, sd_position)
        return step_size, alpha, delta

is included in this file instead of kernels.py (and merged with what's in eca.py; at this stage there is no reason to separate the two)? kernels.py is supposed to be mostly interface code and plumbing.

rlouf avatar Jul 07 '22 19:07 rlouf

There is still one big change to make: move all the meads adaptation logic over to eca.py (and rename it meads.py) so that only the plumbing is left in kernels.py.

I've created meads.py as discussed. I've also added a new file called chain_adaptation.py, which includes functions that perform parallel ensemble chain adaptation and cross-chain adaptation.

While implementing ChEES-HMC on blackjax's API for my research, I've noticed that MEADS' Algorithm 3 is implemented as an adaptive MCMC algorithm, in the sense that it is meant to generate correct samples as it adapts. For an algorithm like this to be ergodic it needs to follow certain criteria, which is why we need parallel ensemble chain adaptation. On the other hand, ChEES' Algorithm 1 is not an adaptive algorithm: it is meant to adapt during a warm-up phase (generating incorrect samples), after which the algorithm's parameters are fixed and correct samples can be generated. MEADS can also be run like this, not as adaptive MCMC, so I've given the user the choice (bool parameter eca) of running adaptive MCMC (where the samples in warmup_states can be used) or an adaptation that returns a ghmc kernel with fixed parameters (same as window_adaptation). We have the choice of implementing these options as two separate kernels or as one kernel with a bool parameter.

In any case, I'll need cross-chain adaptation to implement ChEES, so I've included it in a new file (I also think it'll be useful for other adaptive algorithms).

decide what we do about pmap vs vmap

I've included a batch_fn option for the user and fixed other minor things.

Let me know what you think about the current implementation and whether there is anything that can be improved. I'm on vacation for the coming two weeks, but after that I'll add docstrings and tests, finish this PR, open a PR for ChEES, and add an example of MEADS and ChEES, probably reproducing the experiments of the MEADS paper or something similar.

albcab avatar Jul 15 '22 01:07 albcab

I've left only the version of MEADS that works like window_adaptation for GHMC. It does change the current API for warm-ups, and in various experiments I've found no improvement with the adaptive version.

Also, there are some examples using real data that I thought could be interesting to have in the documentation.

albcab avatar Aug 25 '22 10:08 albcab

Almost there! Two things related to the notebooks: 1- we can probably skip MEADS.ipynb, it does not add anything that is not in the other two; 2- the R-hat diagnostic takes values greater than 1 in the sparse regression model, which means the chains are not mixing well.
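
For reference, one way to check this (assuming blackjax's potential_scale_reduction diagnostic; any R-hat implementation would do):

    import jax
    from blackjax.diagnostics import potential_scale_reduction

    # `samples` has shape (num_chains, num_samples) for a single parameter;
    # values close to 1 indicate good mixing, noticeably larger values do not.
    samples = jax.random.normal(jax.random.PRNGKey(0), (4, 1000))
    print(potential_scale_reduction(samples))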

rlouf avatar Sep 05 '22 06:09 rlouf

Great! I've made the corrections you recommended and skipped MEADS.ipynb. I've noticed all notebooks are .md now; should I push them as .md files directly? Also, I've improved the R-hat on the sparse regression model by increasing the number of iterations (given the high dimension of the target space and the hierarchies, this model tends to be a hard one to sample from). Comparisons with ChEES and other parallel MCMC methods should be interesting on these examples.

albcab avatar Sep 09 '22 15:09 albcab

Great! Yes would you mind converting the notebooks to .md following the instructions in the README? Good on my end otherwise.

@junpenglao do you want to take a look?

rlouf avatar Sep 09 '22 16:09 rlouf

@junpenglao ping

rlouf avatar Sep 15 '22 21:09 rlouf

Per discussion with @rlouf offline, we will merge first and make changes to the API as part of other refactoring efforts.

junpenglao avatar Sep 17 '22 06:09 junpenglao

Great job! Thank you for another great contribution and your patience with the process.

rlouf avatar Sep 17 '22 06:09 rlouf