Add inverse_mass_matrix to MCLMC

Open · reubenharry opened this issue 1 year ago · 3 comments

Current behavior

The MCLMC algorithm in Jakob's original repo has preconditioning (i.e., an inverse mass matrix) built into the integrator. In blackjax, MCLMC currently has no option for preconditioning at all.

Desired behavior

There should be an option to use an inverse mass matrix, just like in HMC. The tuning algorithm should also propose such a matrix.

I should test that this reproduces the results in the original repo.

reubenharry, Dec 09 '23 14:12

How about computing the covariance during the adaptation of L, here: https://github.com/blackjax-devs/blackjax/blob/40589713d91d4249c4b14a4713232cd73ba34b60/blackjax/adaptation/mclmc_adaptation.py#L279

Since we are already storing the samples here to compute the ESS, computing the covariance is straightforward.
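A minimal sketch of that estimate, assuming the stored samples can be stacked into a `(num_samples, dim)` array. `estimate_inverse_mass_matrix` is a hypothetical helper (not a blackjax function), and NumPy stands in for `jax.numpy`. As in HMC tuning, the inverse mass matrix is taken to be the (diagonal or dense) sample covariance of the target.

```python
import numpy as np


def estimate_inverse_mass_matrix(samples, diagonal=True):
    """Hypothetical helper: estimate an inverse mass matrix from stored samples.

    samples has shape (num_samples, dim), e.g. the states already kept for
    the ESS computation. Returns per-dimension variances (diagonal=True) or
    the full sample covariance (diagonal=False).
    """
    samples = np.asarray(samples, dtype=float)
    if diagonal:
        return np.var(samples, axis=0, ddof=1)
    return np.cov(samples, rowvar=False)   # rows are samples, columns are dims


# Usage on synthetic Gaussian draws standing in for MCLMC samples.
rng = np.random.default_rng(0)
true_cov = np.array([[4.0, 1.0], [1.0, 2.0]])
samples = rng.multivariate_normal(np.zeros(2), true_cov, size=20_000)
diag_estimate = estimate_inverse_mass_matrix(samples)
dense_estimate = estimate_inverse_mass_matrix(samples, diagonal=False)
```

Real chain samples are autocorrelated, so the estimate is noisier than with independent draws, but it is still consistent.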

Then we just need to add the preconditioning back into the integrator (as an optional kwarg).
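One way the optional kwarg could look, as a sketch under several assumptions: the inverse mass matrix is diagonal, the integrator uses velocity-Verlet ordering, the kinetic-energy-change bookkeeping is omitted, and the ESH momentum update is written from my reading of the MCLMC papers rather than copied from blackjax. `mclmc_leapfrog_step` and its `inverse_mass_matrix` argument are hypothetical names. With `sigma = sqrt(inverse_mass_matrix)`, the gradient is rescaled by `sigma` and the position moves by `step_size * sigma * u`, which is exactly the unpreconditioned step on `z = x / sigma`.

```python
import numpy as np


def esh_momentum_update(u, grad, step_size):
    """ESH-dynamics update of the unit-norm MCLMC velocity u.

    grad is the gradient of the log density in the coordinates the
    dynamics run in. Assumes dim >= 2 and grad != 0.
    """
    dim = u.shape[0]
    g_norm = np.linalg.norm(grad)
    e = grad / g_norm                      # unit vector along the gradient
    ue = np.dot(u, e)                      # projection of u onto e
    delta = step_size * g_norm / (dim - 1)
    zeta = np.exp(-delta)
    new_u = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u
    return new_u / np.linalg.norm(new_u)   # renormalize against round-off


def mclmc_leapfrog_step(x, u, grad_logdensity_fn, step_size,
                        inverse_mass_matrix=None):
    """One velocity-Verlet MCLMC step with optional diagonal preconditioning.

    inverse_mass_matrix (hypothetical kwarg) holds per-dimension variances.
    With sigma = sqrt(inverse_mass_matrix) this is the plain step applied to
    z = x / sigma, whose log density is z -> logdensity(sigma * z).
    """
    if inverse_mass_matrix is None:
        sigma = np.ones_like(x)
    else:
        sigma = np.sqrt(inverse_mass_matrix)
    # chain rule: grad_z logdensity(sigma * z) = sigma * grad_x logdensity(x)
    u = esh_momentum_update(u, sigma * grad_logdensity_fn(x), 0.5 * step_size)
    x = x + step_size * sigma * u          # dz = step_size * u, so dx = sigma * dz
    u = esh_momentum_update(u, sigma * grad_logdensity_fn(x), 0.5 * step_size)
    return x, u
```

Because the velocity stays on the unit sphere, the preconditioning lives entirely in the position update and the rescaled gradient.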

junpenglao, Dec 13 '23 14:12

Yep, I think that's the thing to do in the tuning. There's some code in the original for that purpose.

Regarding adding the preconditioning back in, I'm unsure of the best approach. Surely preconditioning is equivalent to reparameterizing, i.e. changing logdensity_fn, which strikes me as a more direct way to implement it.
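The logdensity_fn route can be sketched as a wrapper that expresses the target in whitened coordinates `z`, with `x = L z` for a Cholesky factor `L` of the estimated covariance. `precondition_logdensity` is a hypothetical name, and NumPy stands in for `jax.numpy`. Since the map is linear, its log-Jacobian is constant and can be dropped for MCMC.

```python
import numpy as np


def precondition_logdensity(logdensity_fn, chol_cov):
    """Return the target expressed in whitened coordinates z, with x = L z.

    chol_cov is a lower-triangular L with L @ L.T approximating the
    covariance of x. The map is linear, so its log-Jacobian log|det L| is a
    constant and is dropped.
    """
    def logdensity_z(z):
        return logdensity_fn(chol_cov @ z)
    return logdensity_z


# Usage: a correlated Gaussian becomes a standard normal in z.
cov = np.array([[4.0, 1.0], [1.0, 2.0]])
precision = np.linalg.inv(cov)
logdensity_fn = lambda x: -0.5 * x @ precision @ x
L = np.linalg.cholesky(cov)
logdensity_z = precondition_logdensity(logdensity_fn, L)
z = np.array([0.3, -1.2])
x = L @ z   # map a whitened sample back to the original coordinates
```

The sampler then runs unchanged on `logdensity_z`, and every sample has to be mapped back with `x = L @ z`.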

reubenharry, Dec 13 '23 16:12

Yeah, the two are equivalent; for example, TFP goes with modifying logdensity_fn. Since all the other samplers in blackjax apply preconditioning to the momentum instead (see the kinetic energy and momentum sampling logic), let's stay with the same approach.
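To illustrate the momentum-side convention and why it matches the reparameterization, here is an HMC-style sketch (not blackjax's metrics code; all names are hypothetical, NumPy in place of JAX). Momentum is drawn as `p ~ N(0, M)` with `M = inverse_mass_matrix^{-1}`, the kinetic energy is `0.5 * p^T M^{-1} p`, and the position moves along `M^{-1} p`. With `M^{-1} = L L^T`, this is a standard identity-mass leapfrog step on `z = L^{-1} x` with momentum `L^T p`.

```python
import numpy as np


def sample_momentum(rng, inverse_mass_matrix):
    """Draw p ~ N(0, M) with M = inverse(inverse_mass_matrix)."""
    mass_matrix = np.linalg.inv(inverse_mass_matrix)
    noise = rng.standard_normal(mass_matrix.shape[0])
    return np.linalg.cholesky(mass_matrix) @ noise


def kinetic_energy(p, inverse_mass_matrix):
    """K(p) = 0.5 * p^T M^{-1} p."""
    return 0.5 * p @ inverse_mass_matrix @ p


def leapfrog(x, p, grad_logdensity_fn, step_size, inverse_mass_matrix):
    """One leapfrog step with preconditioning applied through the momentum."""
    p = p + 0.5 * step_size * grad_logdensity_fn(x)
    x = x + step_size * inverse_mass_matrix @ p   # dK/dp = M^{-1} p
    p = p + 0.5 * step_size * grad_logdensity_fn(x)
    return x, p
```

Choosing one convention across samplers keeps the `inverse_mass_matrix` argument meaning the same thing everywhere, which is the consistency argument made above.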

junpenglao, Dec 13 '23 16:12