Add the waste free SMC sampler

Open AdrienCorenflos opened this issue 3 years ago • 13 comments

This PR is to implement waste-free SMC (see https://arxiv.org/abs/2011.02328) invented by @hai-dang-dau and @nchopin, see https://github.com/blackjax-devs/blackjax/discussions/87. I decided to pick up this item after the latest learnbayesstats episode release where SMC was mentioned again (@AlexAndorra :wave:).

So far it seems to be working, apart from the normalizing constant unit test for dim=10. I'd appreciate it if the two paper authors could share some views as to what is happening.

Roadmap to merging

  • [x] Understand why the unittest fails in dim=10
  • [x] Adapt the tempered SMC notebook example.
  • [x] Fix style issues

AdrienCorenflos avatar Sep 22 '21 20:09 AdrienCorenflos

Opened issue https://github.com/blackjax-devs/blackjax/issues/118

AdrienCorenflos avatar Sep 22 '21 21:09 AdrienCorenflos

Codecov Report

Merging #117 (388acb1) into main (6f728a9) will increase coverage by 0.03%. The diff coverage is 100.00%.

:exclamation: Current head 388acb1 differs from pull request most recent head aae50d4. Consider uploading reports for the commit aae50d4 to get more accurate results

@@            Coverage Diff             @@
##             main     #117      +/-   ##
==========================================
+ Coverage   98.56%   98.60%   +0.03%     
==========================================
  Files          25       25              
  Lines         977     1004      +27     
==========================================
+ Hits          963      990      +27     
  Misses         14       14              
Impacted Files                          Coverage Δ
blackjax/inference/smc/base.py          100.00% <100.00%> (ø)
blackjax/inference/smc/resampling.py    100.00% <100.00%> (ø)
blackjax/tempered_smc.py                100.00% <100.00%> (ø)
blackjax/diagnostics.py                 100.00% <0.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 6f728a9...aae50d4.

codecov[bot] avatar Sep 23 '21 12:09 codecov[bot]

Ready for review. For some reason the pre-commit thingy does not correct the notebook. I am not sure where exactly the problem with it is...

AdrienCorenflos avatar Sep 23 '21 12:09 AdrienCorenflos

> For some reason the pre-commit thingy does not correct the notebook. I am not sure where exactly the problem with it is...

Essentially it says it fixed it (not sure what it's fixing though) but then it ends up not changing the file :man_shrugging:

[Image: screenshot from 2021-09-23 15-53-25]

AdrienCorenflos avatar Sep 23 '21 12:09 AdrienCorenflos

Also made residual resampling allow for sub and supersampling while at it.
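
For readers unfamiliar with the change, the sketch below illustrates what residual resampling with an output size different from the number of input particles could look like. It is plain Python using only the standard library, not the blackjax implementation; the function name `residual_resample` and its arguments are assumptions made for illustration.

```python
import math
import random


def residual_resample(weights, num_samples, rng):
    """Return `num_samples` ancestor indices drawn by residual resampling.

    `num_samples` may be smaller (subsampling) or larger (supersampling)
    than `len(weights)`. Weights do not need to be normalized.
    """
    total = sum(weights)
    # Expected number of copies of each particle among `num_samples` draws.
    expected = [num_samples * w / total for w in weights]
    # Deterministic part: each particle gets floor(expected) copies.
    counts = [math.floor(e) for e in expected]
    indices = [i for i, c in enumerate(counts) for _ in range(c)]
    # Random part: fill the remaining slots in proportion to the residuals.
    remaining = num_samples - len(indices)
    if remaining > 0:
        residuals = [e - c for e, c in zip(expected, counts)]
        indices += rng.choices(range(len(weights)), weights=residuals, k=remaining)
    return indices


rng = random.Random(0)
weights = [1, 2, 3, 4]
subsample = residual_resample(weights, 2, rng)     # fewer indices than particles
supersample = residual_resample(weights, 10, rng)  # more indices than particles
```

With these weights and `num_samples=10` the expected counts are whole numbers, so the supersample is fully deterministic (one copy of particle 0, two of particle 1, and so on). The subsample, by contrast, is drawn entirely from the residuals.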

AdrienCorenflos avatar Sep 23 '21 18:09 AdrienCorenflos

Can you relax the style requirements for notebooks? I ran pre-commit locally and it reformatted the notebook, but it's not working here.

reformatted notebooks/TemperedSMC.ipynb
All done! ✨ 🍰 ✨
1 file reformatted, 5 files left unchanged.

I have to manually go in and add a new line using a text editor every time.

AdrienCorenflos avatar Oct 04 '21 17:10 AdrienCorenflos

@MarcoGorelli what could be causing that?

rlouf avatar Oct 04 '21 17:10 rlouf

Hey - I'm currently on holiday, I can take a closer look next week when I'm back

You can use exclude to not run end-of-file-fixer on notebooks - is that the only hook that's causing issues?

PS: I contributed nbqa-black upstream to black, so I'd suggest using their black-jupyter hook directly
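
As an illustration of the two suggestions above, a `.pre-commit-config.yaml` fragment might look like the following. This is a sketch, not the configuration used in the repository: the `rev` values are placeholders to be replaced by whatever versions the project pins, and the `exclude` pattern assumes that only `.ipynb` files should be skipped.

```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1  # placeholder: pin to the version the project uses
    hooks:
      - id: end-of-file-fixer
        exclude: \.ipynb$  # do not run this hook on notebooks
  - repo: https://github.com/psf/black
    rev: 21.9b0  # placeholder: pin to the version the project uses
    hooks:
      - id: black-jupyter  # formats both .py files and notebooks
```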

MarcoGorelli avatar Oct 04 '21 17:10 MarcoGorelli

@AdrienCorenflos Alright I rebased your branch on the main branch (fixed many merge conflicts) and squashed the commits. Two things before merging:

  • I don't like the kwarg and branching logic inside the SMC kernel. I need to re-read the paper (https://arxiv.org/abs/2011.02328) and see if we can somehow refactor this nicely;
  • We should find an example where the waste-free version clearly outperforms vanilla SMC (maybe one from the paper?)

rlouf avatar Dec 27 '21 19:12 rlouf

The paper we wrote with @hai-dang-dau on waste-free SMC (here) contains 3 numerical experiments; in each experiment, waste-free outperforms vanilla SMC. I particularly like the one on counting Latin squares. However, it is less relevant to statistical inference, and the Metropolis step is non-standard (since the sampling space is the space of k x k matrices where each row is a permutation of 1, ..., k). The first example on logistic regression is a bit more pedestrian, but it should be less hassle. My two cents! :-)

nchopin avatar Jan 07 '22 15:01 nchopin

I think I've figured out what bothered me. The base SMC kernel was designed as a mere sampler when it is in fact much more general. Reading your book more carefully helped, @nchopin! I will try to merge this when I'm done working on #279.

My only concern is memory usage. But JAX code is compiled with XLA, which I hope is smart enough not to allocate the memory when only the last particle of each chain is used for resampling. I will need to run some tests.
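
To make the concern concrete, here is a minimal sketch of the two move steps being contrasted. It is plain Python, not blackjax code, and it follows the usual description of waste-free SMC: M ancestors are resampled, each starts a chain of length P, and all M x P states are kept, whereas standard SMC keeps only the last state of each chain. The names `run_chain`, `waste_free_move` and `standard_move` are hypothetical.

```python
import random


def run_chain(rng, start, mcmc_step, length):
    """Return the `length` successive states of a chain started at `start`."""
    states = [start]
    for _ in range(length - 1):
        states.append(mcmc_step(rng, states[-1]))
    return states


def waste_free_move(rng, particles, weights, mcmc_step, num_chains):
    """Resample `num_chains` ancestors and keep every state of every chain."""
    n = len(particles)
    if n % num_chains != 0:
        raise ValueError("the number of particles must be a multiple of num_chains")
    chain_length = n // num_chains
    ancestors = rng.choices(range(n), weights=weights, k=num_chains)
    chains = [run_chain(rng, particles[a], mcmc_step, chain_length) for a in ancestors]
    # All intermediate states become particles: nothing is discarded.
    return [state for chain in chains for state in chain]


def standard_move(rng, particles, weights, mcmc_step, num_mcmc_steps):
    """Resample as many ancestors as particles and keep only each chain's last state."""
    n = len(particles)
    ancestors = rng.choices(range(n), weights=weights, k=n)
    return [run_chain(rng, particles[a], mcmc_step, num_mcmc_steps + 1)[-1] for a in ancestors]


def random_walk(rng, x):
    """Toy stand-in for an MCMC kernel (no accept/reject step)."""
    return x + rng.gauss(0.0, 1.0)


rng = random.Random(0)
particles = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
weights = [1.0] * 6
moved_waste_free = waste_free_move(rng, particles, weights, random_walk, num_chains=2)
moved_standard = standard_move(rng, particles, weights, random_walk, num_mcmc_steps=2)
```

Both functions return as many particles as they received. The difference is that a generic kernel serving both cases has to materialize whole chains, which is where the memory question comes from.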

rlouf avatar Jul 03 '22 12:07 rlouf

Well, I'm happy that you found the book useful! :-) So what you're saying by "more general" than a "mere sampler" is that you want to be able to approximate also all the intermediate distributions, and this will be expensive, memory-wise, because you need to save the particles (and their weights) at each iteration t?

nchopin avatar Sep 19 '22 07:09 nchopin

> So what you're saying by "more general" than a "mere sampler"

This statement is about my former mental representation of SMC, which comes from existing probabilistic languages. They tend to have one implementation per application. But the generality of the Feynman-Kac model (good memories from my past physicist life!) made me realize that we can write a much more modular implementation (roughly Algorithm 10.1) that has all these specific implementations as particular cases. This may seem like splitting hairs to many, but implementing the conceptually correct structure for the HMC family instead of multiplying the particular cases has paid dividends.
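
A rough sketch of what such a modular step could look like is given below, in plain Python rather than JAX. It is not the blackjax API. The function `smc_step` and its arguments are assumptions: the caller supplies the Markov kernel (`mutate`), the potential function (`log_potential`, taken here to depend on the new state only) and the resampling scheme, so a specific sampler becomes a particular choice of these three ingredients.

```python
import math
import random


def multinomial_resample(weights, num_samples, rng):
    """Draw ancestor indices with probability proportional to `weights`."""
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


def smc_step(rng, particles, weights, mutate, log_potential, resample=multinomial_resample):
    """One resample / mutate / reweight iteration of a generic SMC sampler."""
    n = len(particles)
    # 1. Resample ancestors according to the current weights.
    ancestors = resample(weights, n, rng)
    # 2. Mutate each selected particle with the user-supplied Markov kernel.
    moved = [mutate(rng, particles[a]) for a in ancestors]
    # 3. Reweight with the user-supplied potential, working in log space.
    log_g = [log_potential(x) for x in moved]
    max_log_g = max(log_g)
    unnormalized = [math.exp(lg - max_log_g) for lg in log_g]
    total = sum(unnormalized)
    new_weights = [u / total for u in unnormalized]
    # Log of the mean potential: this step's contribution to the
    # log normalizing-constant estimate.
    log_increment = max_log_g + math.log(total / n)
    return moved, new_weights, log_increment


rng = random.Random(0)
particles = [rng.gauss(0.0, 1.0) for _ in range(8)]
weights = [1.0 / 8] * 8
particles, weights, log_increment = smc_step(
    rng,
    particles,
    weights,
    mutate=lambda r, x: x + r.gauss(0.0, 0.5),  # toy random-walk kernel
    log_potential=lambda x: -0.5 * x * x,       # toy Gaussian potential
)
```

Swapping the `mutate` or `resample` argument changes the algorithm without touching the step itself, which is the kind of modularity described above.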

> is that you want to be able to approximate also all the intermediate distributions, and this will be expensive, memory-wise, because you need to save the particles (and their weights) at each iteration t?

Well, in all generality we need to save the particles at each iteration, otherwise algorithms like waste-free SMC will need their own implementation. My concern is not rooted in any experience, and I don't have the memory footprint of one particle in mind to do a back-of-the-envelope calculation, so I might be worrying for no reason. I am fine "particular casing" waste-free SMC if need be, but only if I'm sure the alternative will be problematic.
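
The back-of-the-envelope calculation mentioned here can be sketched as follows. All numbers are hypothetical, and the estimate assumes that a particle is just a vector of `dim` 32-bit floats plus one weight. It ignores anything else a sampler state may carry, such as cached log-densities or gradients.

```python
def particle_memory_bytes(num_particles, dim, bytes_per_value=4):
    """Bytes needed to store the positions and one weight per particle."""
    return num_particles * (dim + 1) * bytes_per_value


# Hypothetical setting: 10,000 particles in dimension 100, stored as float32.
one_iteration = particle_memory_bytes(10_000, 100)  # 4_040_000 bytes, about 4 MB
# Keeping the particles of 50 iterations multiplies the footprint accordingly.
fifty_iterations = 50 * one_iteration                # 202_000_000 bytes, about 202 MB
```

Under these assumptions a single particle system is small, and the cost only becomes noticeable when many iterations, or much larger states, are kept in memory at once.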

Anyway, this is very interesting. The waste-free approach collides with some experimental work I've done on decoupling the integration steps from the sample generation in HMC (recycled HMC, but without having to resample from history), and I hope to be able to use this "perpetual HMC" algorithm with a variant of the waste-free SMC sampler. If that makes any sense.

rlouf avatar Sep 19 '22 08:09 rlouf

I am closing this since the refactor in #279 now makes it possible to build waste-free samplers.

rlouf avatar Jan 13 '23 13:01 rlouf