blackjax
Add new parameter tuning for SMC
Thank you for opening a PR!
A few important guidelines and requirements before we can merge your PR:
- [N/A] If I add a new sampler, there is an issue discussing it already;
- [x] We should be able to understand what the PR does from its title only;
- [x] There is a high-level description of the changes;
  (1) Implements the possibility of tuning SMC kernels based on the particles.
  (2) As a subcase of (1), allows tuning a proposal distribution.
  (3) As a subcase of (2), allows tuning a multivariate normal distribution.
- [x] There are links to all the relevant issues, discussions and PRs; see step 2 of: https://github.com/blackjax-devs/blackjax/issues/245
- [x] The branch is rebased on the latest main commit;
- [x] Commit messages follow these guidelines;
- [x] The code respects the current naming conventions;
- [x] Docstrings follow the numpy style guide;
- [x] pre-commit is installed and configured on your machine, and you ran it before opening the PR;
- [x] There are tests covering the changes;
- [x] The doc is up-to-date;
- [N/A] If I add a new sampler, I added/updated related examples.
Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.
@rlouf @aloctavodia
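To make items (1) to (3) concrete, here is a minimal JAX sketch of tuning a multivariate normal proposal from a set of particles. It is not the code in this PR: the helper names particles_mean_and_cov, tune_normal_proposal and normal_proposal are hypothetical, and the 2.38**2 / d covariance scaling is a common random-walk Metropolis heuristic assumed here, not something the PR specifies.

```python
import jax
import jax.numpy as jnp


def particles_mean_and_cov(particles, weights=None):
    """Weighted empirical mean and covariance of an (n, d) array of particles.

    Hypothetical helper for illustration; not the blackjax API.
    """
    n = particles.shape[0]
    if weights is None:
        weights = jnp.full((n,), 1.0 / n)  # uniform weights by default
    weights = weights / weights.sum()  # normalize so weights sum to 1
    mean = weights @ particles  # shape (d,)
    centered = particles - mean  # shape (n, d)
    cov = (centered * weights[:, None]).T @ centered  # sum_i w_i c_i c_i^T
    return mean, cov


def tune_normal_proposal(particles, weights=None):
    """Random-walk proposal covariance fitted to the particles.

    Assumes the 2.38**2 / d scaling heuristic; other choices are possible.
    """
    _, cov = particles_mean_and_cov(particles, weights)
    d = particles.shape[1]
    return (2.38**2 / d) * cov


def normal_proposal(rng_key, position, proposal_cov):
    """Draw a proposal from N(position, proposal_cov)."""
    return jax.random.multivariate_normal(rng_key, position, proposal_cov)


# Usage: tune the proposal on the current particles, then propose a move.
key = jax.random.PRNGKey(0)
particles = jax.random.normal(key, (1000, 2)) * jnp.array([1.0, 3.0])
proposal_cov = tune_normal_proposal(particles)
new_position = normal_proposal(jax.random.PRNGKey(1), particles[0], proposal_cov)
```

Passing importance weights lets the same helper fit the proposal to a weighted particle population before resampling.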
Codecov Report
Merging #261 (79382b6) into main (d440031) will increase coverage by 0.02%. The diff coverage is 100.00%.
@@ Coverage Diff @@
## main #261 +/- ##
==========================================
+ Coverage 98.63% 98.65% +0.02%
==========================================
Files 43 45 +2
Lines 1757 1789 +32
==========================================
+ Hits 1733 1765 +32
Misses 24 24
Impacted Files | Coverage Δ | |
---|---|---|
blackjax/kernels.py | 99.49% <100.00%> (-0.01%) | ↓ |
blackjax/smc/base.py | 100.00% <100.00%> (ø) | |
blackjax/smc/ess.py | 100.00% <100.00%> (ø) | |
blackjax/smc/parameter_tuning.py | 100.00% <100.00%> (ø) | |
blackjax/smc/particle_utils.py | 100.00% <100.00%> (ø) | |
@junpenglao note that this PR adds new tuning functionality; it is not only a refactor.
I see. In that case, maybe it is better to split adding the functionality from the refactoring?
I think we should think a bit more about introducing the LogProbFn type and weigh the pros and cons, which might be easier to do in a separate PR.
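For context, a hypothetical sketch of what such a type alias could look like, assuming a log-density maps a (PyTree) position to a scalar. The names PyTree, LogProbFn and tempered_logprob are illustrative only and are not taken from blackjax.

```python
from typing import Any, Callable

import jax.numpy as jnp

# Hypothetical aliases for illustration; not taken from blackjax.
PyTree = Any  # a position can be any JAX PyTree (array, dict, tuple, ...)
LogProbFn = Callable[[PyTree], jnp.ndarray]  # position -> scalar log-density


def tempered_logprob(
    logprior_fn: LogProbFn, loglikelihood_fn: LogProbFn, beta: float
) -> LogProbFn:
    """Return the tempered log-density logprior(x) + beta * loglikelihood(x)."""

    def logprob_fn(x: PyTree) -> jnp.ndarray:
        return logprior_fn(x) + beta * loglikelihood_fn(x)

    return logprob_fn


def logprior_fn(x):
    return -0.5 * jnp.sum(x**2)


def loglikelihood_fn(x):
    return -jnp.sum(x**2)


logprob_fn = tempered_logprob(logprior_fn, loglikelihood_fn, beta=0.5)
```

One trade-off: the alias documents intent in signatures, but a static type checker cannot verify the PyTree structure of the input or that the output is a scalar.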
@junpenglao Done, I removed LogProbFn from this one.
@rlouf FYI this PR is ready for review/commenting
Thank you. I've been travelling/working the whole month; I will take a look whenever I have a chance.
The current design adds a coupling between sampling and adaptation that doesn't need to exist. Such coupling does exist to some extent with NUTS and the window adaptation, but AeHMC, for instance, has completely decoupled the act of sampling from the act of updating parameters, and this is what we're headed towards for blackjax==1.0.0.
This is what we need to implement: on the one hand, adaptation algorithms that take a set of particles and return values for the parameters; on the other hand, SMC samplers that take a set of particles and parameters and return a new set of particles.
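A minimal sketch of that split, assuming a random-walk Metropolis mutation kernel: adapt only maps particles to parameters, mutate only maps (particles, parameters) to new particles, and the driver loop is the only place the two meet. The names adapt and mutate are hypothetical, not blackjax API, and the reweighting and resampling steps of a full SMC sampler are omitted for brevity.

```python
import jax
import jax.numpy as jnp


def adapt(particles):
    """Adaptation only: particles in, parameters out (no sampling here)."""
    d = particles.shape[1]
    cov = jnp.cov(particles, rowvar=False).reshape(d, d)
    return {"proposal_cov": (2.38**2 / d) * cov}  # assumed scaling heuristic


def mutate(rng_key, particles, params, logdensity_fn):
    """Sampling only: particles and parameters in, new particles out.

    One random-walk Metropolis step per particle, vmapped over particles.
    """
    keys = jax.random.split(rng_key, particles.shape[0])

    def one_step(key, x):
        key_prop, key_acc = jax.random.split(key)
        proposal = jax.random.multivariate_normal(key_prop, x, params["proposal_cov"])
        log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
        accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
        return jnp.where(accept, proposal, x)  # keep x if rejected

    return jax.vmap(one_step)(keys, particles)


def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)  # standard normal target


key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
particles = 5.0 * jax.random.normal(init_key, (500, 2))
for _ in range(10):
    key, step_key = jax.random.split(key)
    params = adapt(particles)  # tuning is a separate, explicit call
    particles = mutate(step_key, particles, params, logdensity_fn)
```

Because mutate never calls adapt, either side can be swapped (a different tuning rule, a different kernel) without touching the other.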
I will make it easier to decouple sampling and updating parameters in #276. This should simplify the code a lot.
Hi @ciguaran would you like to pick this up again?
@junpenglao yes. Stay tuned
Closing this since https://github.com/blackjax-devs/blackjax/pull/595 implements the functionality described