blackjax
Add the waste free SMC sampler
This PR implements waste-free SMC (see https://arxiv.org/abs/2011.02328), introduced by @hai-dang-dau and @nchopin; see also https://github.com/blackjax-devs/blackjax/discussions/87. I decided to pick up this item after the release of the latest learnbayesstats episode, in which SMC was mentioned again (@AlexAndorra :wave:).
So far it seems to work, apart from the normalizing-constant unit test for dim=10. I'd appreciate it if the two paper authors could share their views on what is happening.
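For readers new to the method: in waste-free SMC, only M ancestors are resampled out of N particles, each ancestor starts an MCMC chain of length P = N / M, and every state of every chain (not just the last one) becomes a particle. The sketch below shows only that move step, assuming a random-walk Metropolis kernel (`rwm_step`), multinomial resampling, and that M divides N; the helper names are hypothetical, not the code in this PR, and the subsequent reweighting by the incremental potential is left out.

```python
import jax
import jax.numpy as jnp


def rwm_step(key, x, logdensity, scale=0.5):
    """One random-walk Metropolis step that leaves `logdensity` invariant."""
    key_proposal, key_accept = jax.random.split(key)
    proposal = x + scale * jax.random.normal(key_proposal, x.shape)
    log_ratio = logdensity(proposal) - logdensity(x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, x)


def waste_free_move(key, particles, log_weights, logdensity, num_resampled):
    """Resample M ancestors and keep every state of M chains of length P = N // M."""
    num_particles = particles.shape[0]
    if num_particles % num_resampled != 0:
        raise ValueError("num_resampled must divide the number of particles")
    chain_length = num_particles // num_resampled
    key_resample, key_mcmc = jax.random.split(key)

    # Multinomial resampling of only M ancestors (an assumption for this sketch).
    ancestors = jax.random.choice(
        key_resample,
        num_particles,
        shape=(num_resampled,),
        p=jax.nn.softmax(log_weights),
    )
    starts = particles[ancestors]

    def one_step(states, step_key):
        # Advance all M chains by one MCMC step and record the new states.
        keys = jax.random.split(step_key, num_resampled)
        new_states = jax.vmap(lambda k, x: rwm_step(k, x, logdensity))(keys, states)
        return new_states, new_states

    step_keys = jax.random.split(key_mcmc, chain_length - 1)
    _, trajectory = jax.lax.scan(one_step, starts, step_keys)

    # Shape (P, M, d): the starting points followed by the P - 1 later states.
    all_states = jnp.concatenate([starts[None], trajectory], axis=0)
    return all_states.reshape(num_particles, particles.shape[1])
```

The resampling cost and the number of MCMC chains scale with M rather than N, while the particle count stays at N.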
Roadmap to merging
- [x] Understand why the unit test fails in dim=10
- [x] Adapt the tempered SMC notebook example.
- [x] Fix style issues
Opened issue https://github.com/blackjax-devs/blackjax/issues/118
Codecov Report
Merging #117 (388acb1) into main (6f728a9) will increase coverage by 0.03%. The diff coverage is 100.00%.
:exclamation: Current head 388acb1 differs from pull request most recent head aae50d4. Consider uploading reports for the commit aae50d4 to get more accurate results
Coverage Diff
Metric | main | #117 | +/-
---|---|---|---
Coverage | 98.56% | 98.60% | +0.03%
Files | 25 | 25 |
Lines | 977 | 1004 | +27
Hits | 963 | 990 | +27
Misses | 14 | 14 |
Impacted Files | Coverage Δ
---|---
blackjax/inference/smc/base.py | 100.00% <100.00%> (ø)
blackjax/inference/smc/resampling.py | 100.00% <100.00%> (ø)
blackjax/tempered_smc.py | 100.00% <100.00%> (ø)
blackjax/diagnostics.py | 100.00% <0.00%> (ø)
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 6f728a9...aae50d4.
Ready for review. For some reason, the pre-commit hook does not fix the notebook. I am not sure where exactly the problem is...
Essentially, it says it fixed it (not sure what it's fixing, though), but then the file ends up unchanged :man_shrugging:
While I was at it, I also made residual resampling support sub- and supersampling.
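Sub- and supersampling here means drawing M indices from N weighted particles with M different from N, which is what waste-free SMC needs when it resamples only M ancestors. A minimal sketch of residual resampling under that convention, assuming normalized weights and a hypothetical `residual_resampling(key, weights, num_samples)` signature (not necessarily the one in `blackjax/inference/smc/resampling.py`): index i is copied floor(M W_i) times, and the remaining slots are drawn multinomially from the residuals.

```python
import jax
import jax.numpy as jnp


def residual_resampling(key, weights, num_samples):
    """Residual resampling returning `num_samples` indices (M may differ from N)."""
    num_particles = weights.shape[0]
    scaled = num_samples * weights
    counts = jnp.floor(scaled).astype(jnp.int32)
    num_deterministic = counts.sum()

    # Deterministic part: index i is copied floor(M * W_i) times. Positions past
    # num_deterministic are padded by jnp.repeat and overwritten below.
    deterministic = jnp.repeat(
        jnp.arange(num_particles), counts, total_repeat_length=num_samples
    )

    # Random part: remaining M - sum(counts) indices drawn from the residual weights.
    residuals = scaled - counts
    total = residuals.sum()
    safe_total = jnp.where(total > 0, total, 1.0)
    probs = jnp.where(total > 0, residuals / safe_total, weights)
    residual_draws = jax.random.choice(key, num_particles, shape=(num_samples,), p=probs)

    positions = jnp.arange(num_samples)
    return jnp.where(positions < num_deterministic, deterministic, residual_draws)
```

Output shapes stay static (always `num_samples`), so the function can be traced under `jax.jit`; the fallback to `weights` only matters when M W_i is integral for every i and no residual draw is used.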
Can you relax the style requirements for notebooks? I ran pre-commit locally and it reformatted the notebook, but the check still fails here.
reformatted notebooks/TemperedSMC.ipynb
All done! ✨ 🍰 ✨
1 file reformatted, 5 files left unchanged.
Every time, I have to open the file in a text editor and add a trailing newline by hand.
@MarcoGorelli what could be causing that?
Hey, I'm currently on holiday; I can take a closer look next week when I'm back.
You can use the exclude option to skip end-of-file-fixer on notebooks. Is that the only hook causing issues?
P.S. I contributed nbqa-black upstream to black, so I'd suggest using black's black-jupyter hook directly.
@AdrienCorenflos Alright, I rebased your branch on main (fixing many merge conflicts) and squashed the commits. Two things before merging:
- I don't like the kwarg and branching logic inside the smc kernel. I need to re-read the paper (https://arxiv.org/abs/2011.02328) and see if we can refactor this nicely.
- We should find an example where the waste-free version clearly outperforms vanilla SMC (maybe one from the paper?).
The paper we wrote with @hai-dang-dau on waste-free SMC (https://arxiv.org/abs/2011.02328) contains 3 numerical experiments; in each of them, waste-free outperforms vanilla SMC. I particularly like the one on counting Latin squares. However, it is less relevant to statistical inference, and the Metropolis step is non-standard (since the sampling space is the space of k × k matrices where each row is a permutation of 1, ..., k). The first example, on logistic regression, is a bit more pedestrian, but it should be less hassle. My two cents! :-)
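A benchmark based on the logistic regression example would need a tempered target. The sketch below builds one in JAX, assuming 0/1 labels, an isotropic Gaussian prior with scale 5 (the prior used in the paper may differ), and the usual geometric path from prior to posterior; `make_logistic_regression` is a hypothetical helper, and the datasets from the paper are not reproduced here.

```python
import jax
import jax.numpy as jnp


def make_logistic_regression(X, y, prior_scale=5.0):
    """Log-prior, log-likelihood and tempered log-density for Bayesian logistic regression.

    X has shape (n, d), y holds 0/1 labels, and the prior is an isotropic
    Gaussian N(0, prior_scale**2 I), which is an assumption of this sketch.
    """

    def log_prior(beta):
        # Gaussian prior, up to an additive constant.
        return -0.5 * jnp.sum((beta / prior_scale) ** 2)

    def log_likelihood(beta):
        logits = X @ beta
        # Bernoulli log-likelihood: y * logit - log(1 + exp(logit)).
        return jnp.sum(y * logits - jnp.logaddexp(0.0, logits))

    def tempered_logdensity(beta, temperature):
        # Geometric path: prior at temperature 0, posterior at temperature 1.
        return log_prior(beta) + temperature * log_likelihood(beta)

    return log_prior, log_likelihood, tempered_logdensity
```

Running the waste-free and vanilla samplers on the same `tempered_logdensity` with the same total number of particles N would give a like-for-like comparison.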
I think I've figured out what bothered me. The base SMC kernel was designed as a mere sampler, when it is in fact much more general. Reading your book more carefully helped, @nchopin! I will try to merge this when I'm done working on #279.
My only concern is memory usage. JAX code is compiled with XLA, which I hope is smart enough not to allocate memory when only the last particle of each chain is used for resampling. I will need to run some tests.
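Instead of relying on XLA to eliminate an unused buffer, the choice can be made explicit with `jax.lax.scan`: returning `None` as the per-step output means the intermediate states are never stacked, whereas returning the state allocates an output of shape `(num_steps, *state.shape)`. Waste-free SMC keeps all P states of each chain, so it needs the second form; a kernel that only uses the last state can use the first. The toy AR(1) kernel `ar1_step` is a hypothetical stand-in for an MCMC step.

```python
import jax
import jax.numpy as jnp


def ar1_step(key, x):
    # Toy Markov kernel standing in for an MCMC step.
    return 0.5 * x + jax.random.normal(key, x.shape)


def run_chain_last_state(key, init, num_steps):
    """Only the final state is returned; intermediate states are never stacked."""

    def body(x, step_key):
        return ar1_step(step_key, x), None

    last, _ = jax.lax.scan(body, init, jax.random.split(key, num_steps))
    return last


def run_chain_full_trajectory(key, init, num_steps):
    """Every state is emitted, so the output has shape (num_steps, *init.shape)."""

    def body(x, step_key):
        x = ar1_step(step_key, x)
        return x, x

    _, trajectory = jax.lax.scan(body, init, jax.random.split(key, num_steps))
    return trajectory
```

With the same key, the last row of the full trajectory equals the result of the last-state version, so the two differ only in what they keep in memory.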
Well, I'm happy that you found the book useful! :-) So what you're saying by "more general" than a "mere sampler" is that you also want to be able to approximate all the intermediate distributions, and that this will be expensive memory-wise, because you need to save the particles (and their weights) at each iteration t?
> So what you're saying by "more general" than a "mere sampler"
This statement is about my former mental model of SMC, which comes from existing probabilistic programming languages. They tend to have one implementation per application. But the generality of the Feynman-Kac model (good memories from my past life as a physicist!) made me realize that we can write a much more modular implementation (roughly Algorithm 10.1) that has all these specific implementations as particular cases. This may seem like splitting hairs to many, but implementing the conceptually correct structure for the HMC family, instead of multiplying special cases, has paid dividends.
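A modular step in this spirit can be written once in terms of a Markov kernel M_t (`move_fn`) and a log-potential log G_t (`log_potential_fn`). The sketch below assumes multinomial resampling at every step, which is what makes the mean of the new weights an estimate of Z_t / Z_{t-1}; the function name and signature are hypothetical and are neither the blackjax API nor the refactor from #279.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def generic_smc_step(key, particles, log_weights, move_fn, log_potential_fn):
    """Resample, move with M_t, reweight with G_t.

    Also returns the estimate of log(Z_t / Z_{t-1}), which is valid here because
    this sketch resamples (multinomially) at every step.
    """
    num_particles = particles.shape[0]
    key_resample, key_move = jax.random.split(key)

    # 1. Resample according to the normalized previous weights.
    ancestors = jax.random.choice(
        key_resample, num_particles, shape=(num_particles,), p=jax.nn.softmax(log_weights)
    )
    resampled = particles[ancestors]

    # 2. Move each particle with the Markov kernel M_t.
    move_keys = jax.random.split(key_move, num_particles)
    moved = jax.vmap(move_fn)(move_keys, resampled)

    # 3. Reweight with the log-potential log G_t.
    new_log_weights = jax.vmap(log_potential_fn)(moved)
    log_evidence_increment = logsumexp(new_log_weights) - jnp.log(num_particles)
    return moved, new_log_weights, log_evidence_increment
```

A tempered SMC sampler corresponds to a `move_fn` that leaves the current tempered target invariant and a `log_potential_fn` equal to the temperature increment times the log-likelihood. Waste-free SMC changes only steps 1 and 2 (resample M ancestors, keep all P states of each chain), which is why a structure like this can host it without a dedicated kwarg.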
> is that you want to be able to approximate also all the intermediate distributions, and this will be expensive, memory-wise, because you need to save the particles (and their weights) at each iteration t?
Well, in full generality we need to save the particles at each iteration; otherwise, algorithms like waste-free SMC will need their own implementation. My concern is not rooted in any experience, and I don't have the memory footprint of one particle in mind to do a back-of-the-envelope calculation, so I might be worrying for no reason. I am fine with special-casing waste-free SMC if need be, but only if I'm sure the alternative would be problematic.
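A back-of-the-envelope count for storing the full particle history is easy to write down. This sketch assumes each particle is a flat vector of `dim` floats plus one log-weight, all of the same dtype; it ignores auxiliary state such as HMC momenta, gradients or kernel diagnostics, which would add to the total. The sizes in the example are hypothetical.

```python
def particle_history_bytes(num_particles, dim, num_iterations, itemsize=4):
    """Memory needed to keep every particle and its log-weight at every iteration.

    `itemsize` is the number of bytes per float: 4 for float32, 8 for float64.
    """
    per_iteration = num_particles * (dim + 1) * itemsize
    return per_iteration * num_iterations


# Hypothetical sizes: 10,000 particles in 100 dimensions over 50 tempering steps.
print(particle_history_bytes(10_000, 100, 50) / 1e6, "MB")  # → 202.0 MB
```

The footprint is linear in N, d and T, so keeping only the current iteration divides it by T.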
Anyway, this is very interesting. The waste-free approach overlaps with some experimental work I've done on decoupling the integration steps from sample generation in HMC (recycled HMC, but without having to resample from the history), and I hope to be able to use this "perpetual HMC" algorithm with a variant of the waste-free SMC sampler, if that makes any sense.
I am closing this, since the refactor in #279 now makes it possible to build waste-free samplers.