blackjax
SMC: optimizations
Hi all, in the context of SMC it makes sense to have an independent version of RMH. "Independent" here means that the proposal distribution does not depend on the particle being mutated, but it may depend on the whole particle population. For example, the proposal could be a multivariate normal distribution whose covariance matrix is the sample covariance of the particle population.
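A minimal sketch of this idea in JAX, assuming a Gaussian proposal fitted to the population's sample mean and covariance. The names `fit_gaussian_proposal` and `irmh_step` and the toy target density are hypothetical and are not part of the Blackjax API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def fit_gaussian_proposal(particles):
    """Fit N(mean, cov) to the whole population; particles has shape [n, d]."""
    mean = jnp.mean(particles, axis=0)
    cov = jnp.atleast_2d(jnp.cov(particles, rowvar=False))
    return mean, cov


def irmh_step(key, position, logdensity_fn, mean, cov):
    """One independent MH step: the proposal draw ignores `position`."""
    key_prop, key_acc = jax.random.split(key)
    proposal = jax.random.multivariate_normal(key_prop, mean, cov)
    # Acceptance ratio for an independent proposal q:
    # log a = [log pi(x') - log q(x')] - [log pi(x) - log q(x)]
    log_w_new = logdensity_fn(proposal) - multivariate_normal.logpdf(proposal, mean, cov)
    log_w_old = logdensity_fn(position) - multivariate_normal.logpdf(position, mean, cov)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_w_new - log_w_old
    return jnp.where(accept, proposal, position), accept


def logdensity_fn(x):
    """Toy target (assumption): unnormalized Gaussian centred at 1."""
    return -0.5 * jnp.sum((x - 1.0) ** 2)


key_init, key_step = jax.random.split(jax.random.PRNGKey(0))
particles = 2.0 * jax.random.normal(key_init, (500, 2))

# The proposal is fitted once on the population, then shared by all particles.
mean, cov = fit_gaussian_proposal(particles)
keys = jax.random.split(key_step, particles.shape[0])
new_particles, accepted = jax.vmap(
    lambda k, x: irmh_step(k, x, logdensity_fn, mean, cov)
)(keys, particles)
```

Because the proposal parameters are computed once per population, the per-particle step only needs the shared `mean` and `cov`, which is what makes it easy to vectorize with `jax.vmap`.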
Moreover, there are optimizations that can be made, for example mutating a particle several times before resampling.
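The several-mutations idea could be sketched as a wrapper that applies any mutation kernel a fixed number of times. This is only an illustration: `repeat_kernel` is a hypothetical helper, and `rw_kernel` is a placeholder random-walk kernel targeting a standard normal, used here just so the example runs on its own.

```python
import jax
import jax.numpy as jnp


def repeat_kernel(kernel, num_mutations):
    """Wrap `kernel(key, position) -> position` so it is applied `num_mutations` times."""

    def repeated(key, position):
        keys = jax.random.split(key, num_mutations)

        def body(pos, k):
            return kernel(k, pos), None

        final, _ = jax.lax.scan(body, position, keys)
        return final

    return repeated


def rw_kernel(key, position):
    """Placeholder mutation kernel (assumption): random-walk MH on a standard normal."""
    key_prop, key_acc = jax.random.split(key)
    proposal = position + 0.5 * jax.random.normal(key_prop, position.shape)
    log_ratio = -0.5 * jnp.sum(proposal**2) + 0.5 * jnp.sum(position**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, position)


# Particles start deliberately far from the target's mode at the origin.
particles = jnp.full((200, 2), 5.0)
keys = jax.random.split(jax.random.PRNGKey(1), particles.shape[0])

mutate = jax.vmap(repeat_kernel(rw_kernel, num_mutations=50))
mutated = mutate(keys, particles)
```

Using `jax.lax.scan` keeps the repeated mutation inside a single traced loop, so the number of mutations does not inflate the compiled program.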
I am planning on implementing these in Blackjax and wanted to validate the approach. The proposed PR order could be:
1. Implement and test an IRMH sampler with no optimization.
2. Implement tuning based on properties of the particle population.
3. Implement the several-mutations-in-one-step optimization.
4. Implement an example in the examples folder that uses this.
I am proposing implementing 4. only after 1-3 because the algorithm alone is not powerful enough to be of use (at least AFAIK) without the implementations in 2-3 and without being an SMC kernel. LMK if you have any suggestions on a different order, since merging 1. would mean implementing a new sampler without an example (just tests on very simple models).
@rlouf @aloctavodia
Thank you for opening an issue! Is there anything I could read about this? Do you have a design to propose yet?
https://www.jstor.org/stable/4140600?seq=1, Section 4.2. I didn't think this was still considered a valid approach in the continuous case, but maybe I'm wrong. It's perhaps more interesting in the discrete case: https://arxiv.org/pdf/1101.6037.pdf
Just for reference, this is the default kernel in PyMC for SMC and SMC-ABC
Let's see this design!
Hi! I'm not sure whether you have a particular structure for a design document in mind. Let me know if you would like me to extend the following description with diagrams or further explanation:
For 1. My plan is to add a new kernel in kernels.py and test it similarly to what's currently done in test_sampling.py. I can have a draft PR for this step soon.
For 2. I was planning to extract-and-extend from https://github.com/blackjax-devs/blackjax/blob/main/blackjax/kernels.py#L59, with a no-tuning default that matches what happens today, while letting the user choose other tuning strategies. A tuning strategy ends up generating a kernel_factory that depends on the particle population and is reused in each step of the base SMC execution. This will also later allow for tuning other kernels, such as HMC, in the context of SMC.
For 3. Will probably be some decoration on the base IMH, but not sure on this one just yet.
For 4. Will use a Gaussian Mixture model as initial example.
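To make the tuning idea from step 2 more concrete, here is a rough sketch in which a tuning strategy maps the particle population to proposal parameters, and a factory turns that strategy into a kernel. All names (`no_tuning`, `population_moments`, `make_kernel_factory`) and the signatures are assumptions for illustration, not the existing Blackjax interface:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def no_tuning(mean, cov):
    """Default strategy: ignore the population and return fixed parameters."""

    def tune(particles):
        return mean, cov

    return tune


def population_moments(particles):
    """Tuning strategy: proposal parameters from the population's mean and covariance."""
    mean = jnp.mean(particles, axis=0)
    cov = jnp.atleast_2d(jnp.cov(particles, rowvar=False))
    return mean, cov


def make_kernel_factory(tune):
    """Turn a tuning strategy into `kernel_factory(particles, logdensity_fn) -> kernel`."""

    def kernel_factory(particles, logdensity_fn):
        mean, cov = tune(particles)  # would be recomputed at every SMC step

        def log_weight(x):
            return logdensity_fn(x) - multivariate_normal.logpdf(x, mean, cov)

        def kernel(key, position):
            key_prop, key_acc = jax.random.split(key)
            proposal = jax.random.multivariate_normal(key_prop, mean, cov)
            log_ratio = log_weight(proposal) - log_weight(position)
            accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
            return jnp.where(accept, proposal, position)

        return kernel

    return kernel_factory


def logdensity_fn(x):
    """Toy target (assumption): unnormalized standard normal."""
    return -0.5 * jnp.sum(x**2)


key_init, key_step = jax.random.split(jax.random.PRNGKey(2))
particles = 3.0 + jax.random.normal(key_init, (300, 2))
keys = jax.random.split(key_step, particles.shape[0])

tuned_factory = make_kernel_factory(population_moments)
fixed_factory = make_kernel_factory(no_tuning(jnp.zeros(2), jnp.eye(2)))

tuned_particles = jax.vmap(tuned_factory(particles, logdensity_fn))(keys, particles)
fixed_particles = jax.vmap(fixed_factory(particles, logdensity_fn))(keys, particles)
```

Keeping the strategy separate from the kernel is what would let the same factory pattern be reused for other kernels, since only `tune` needs to know which parameters are being adapted.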