pymc4
Add SMC init support/Separate state and executor
This PR adds support for a Sequential Monte Carlo (SMC) method API. SMC is used to tune the `step_size` of the NUTS sampler.
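For readers unfamiliar with SMC, here is a minimal sketch of likelihood-tempered SMC on a toy conjugate model (Normal prior on `mu`, unit-variance Normal likelihood). It uses plain Python rather than pymc4/TFP and is not the PR's implementation. It assumes a fixed linear tempering schedule and random-walk Metropolis mutations, and the names `log_prior`, `log_lik` and `tempered_smc` are hypothetical. Each stage reweights the particles by the likelihood raised to the increment in beta, resamples them, and then moves them toward the intermediate target prior × likelihood^beta.

```python
import math
import random
import statistics


def log_prior(mu):
    # Standard normal prior on mu, up to an additive constant.
    return -0.5 * mu * mu


def log_lik(mu, data):
    # Unit-variance Gaussian likelihood of the observations, up to a constant.
    return -0.5 * sum((y - mu) ** 2 for y in data)


def tempered_smc(data, n_particles=1000, n_stages=10, n_mh=5, seed=0):
    """Hypothetical likelihood-tempered SMC for p(mu | data), beta going 0 -> 1."""
    rng = random.Random(seed)
    # Stage beta = 0: the target is the prior, so sample it directly.
    particles = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]
    # Fixed linear schedule (an assumption); practical SMC usually adapts beta.
    betas = [k / n_stages for k in range(n_stages + 1)]
    for beta_prev, beta in zip(betas, betas[1:]):
        # 1. Reweight: the incremental weight is lik(mu) ** (beta - beta_prev).
        logw = [(beta - beta_prev) * log_lik(p, data) for p in particles]
        top = max(logw)
        weights = [math.exp(lw - top) for lw in logw]  # shifted for stability
        # 2. Resample particles in proportion to their weights (multinomial).
        particles = rng.choices(particles, weights=weights, k=n_particles)
        # 3. Mutate: random-walk Metropolis targeting prior * lik ** beta,
        #    with the proposal scale set from the current particle spread.
        scale = statistics.pstdev(particles) or 1e-3

        def log_target(x, beta=beta):
            return log_prior(x) + beta * log_lik(x, data)

        moved = []
        for x in particles:
            lp = log_target(x)
            for _ in range(n_mh):
                prop = x + rng.gauss(0.0, scale)
                lp_prop = log_target(prop)
                # Accept with probability min(1, exp(lp_prop - lp)).
                if rng.random() < math.exp(min(0.0, lp_prop - lp)):
                    x, lp = prop, lp_prop
            moved.append(x)
        particles = moved
    return particles


data = [1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.0, 1.4, 1.1]
posterior = tempered_smc(data)
print(statistics.mean(posterior), statistics.pstdev(posterior))
```

Because the model is conjugate, the result can be checked. With these ten observations, the exact posterior is Normal with mean 1.0 and standard deviation sqrt(1/11) ≈ 0.30. The spread of the final particles is the kind of information a sampler could use when choosing an initial `step_size`. The thread does not say how the PR maps SMC output onto NUTS settings.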
SMC support:
- [x] Support SMC initialization for the MCMC sampler: `pm.sample(model(), initialize_smc=True/False)`
- [x] Provide a notebook of Jun's work on SMC (partly done)
- [x] Add comprehensive tests, because the logic of the graph executor was modified (this is important)
- [x] Fix issue with `batch_stack`
- [x] Add support for `num_chains` (I've run into some issues but will solve them soon; they may cause high variance in the samples)
- [x] Separate the SMC and MCMC graph execution logic (this will help with batch/draws dims that are not exposed in the user model for SMC, and maybe also with determining bijectors)
- [ ] ~Add support for other kernels~
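The `initialize_smc` flag and `num_chains` interact because each chain needs its own starting point. Below is a minimal sketch of one way to pick those starting points, assuming SMC particles are used directly as chain initializations. The helper `initial_chain_states` and its parameters are hypothetical and are not pymc4's API.

```python
import random
from typing import Callable, List, Sequence


def initial_chain_states(
    num_chains: int,
    initialize_smc: bool,
    run_smc: Callable[[], Sequence[float]],
    default_init: Callable[[], float],
    seed: int = 0,
) -> List[float]:
    """Choose one starting point per chain (hypothetical helper, not pymc4's API)."""
    if not initialize_smc:
        # Fall back to the sampler's usual initialization for every chain.
        return [default_init() for _ in range(num_chains)]
    particles = list(run_smc())
    if len(particles) < num_chains:
        raise ValueError("SMC must return at least one particle per chain")
    # Draw particles without replacement (by position) so chains start from
    # different, approximately posterior, points instead of a shared one.
    return random.Random(seed).sample(particles, num_chains)
```

If too few particles are drawn, or they collapse onto a few values after resampling, the chains start close together. This is one plausible source of the high sample variance mentioned in the `num_chains` item.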
Ready for review, but the SMC tests are slow (better test logic is needed).
The PR also separates the executor logic from the sampling state class. This is done to support an SMC state, which comes with distinct `logp` methods. Separating the executor implementation also removes code repetition in `meta_executor` and `transformed_executor`.
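The state/executor split can be sketched as follows. The model traversal is written once in an executor, and subclasses only swap in a different state class. The SMC state exposes separate prior and likelihood log densities over the same storage. This is a plain-Python sketch. Every class and method name is hypothetical and only loosely modeled on the file names in the coverage report below (`state.py`, `smc_state.py`, `smc_executor.py`). The generator-based toy model yielding `(name, log_density, observed)` tuples is also an assumption, not pymc4's model format.

```python
import math
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class SamplingState:
    # Hypothetical state container: variable values and per-variable log densities.
    values: Dict[str, float] = field(default_factory=dict)
    unobserved_log_probs: Dict[str, float] = field(default_factory=dict)
    observed_log_probs: Dict[str, float] = field(default_factory=dict)

    def collect_log_prob(self) -> float:
        # Joint log density, which is what MCMC targets.
        return sum(self.unobserved_log_probs.values()) + sum(self.observed_log_probs.values())


@dataclass
class SMCSamplingState(SamplingState):
    # Same storage, plus the separate log-density views that SMC needs.
    def prior_log_prob(self) -> float:
        return sum(self.unobserved_log_probs.values())

    def likelihood_log_prob(self) -> float:
        return sum(self.observed_log_probs.values())

    def tempered_log_prob(self, beta: float) -> float:
        return self.prior_log_prob() + beta * self.likelihood_log_prob()


class Executor:
    # The traversal is written once; subclasses only change which state is built.
    state_cls = SamplingState

    def evaluate(self, model, values):
        state = self.state_cls(values=dict(values))
        gen = model()
        try:
            name, logpdf, observed = next(gen)
            while True:
                if observed is not None:
                    state.observed_log_probs[name] = logpdf(observed)
                    value = observed
                else:
                    value = state.values[name]
                    state.unobserved_log_probs[name] = logpdf(value)
                # Send the value back into the model and get the next variable.
                name, logpdf, observed = gen.send(value)
        except StopIteration:
            pass  # the model has no more variables
        return state


class SMCExecutor(Executor):
    state_cls = SMCSamplingState


def normal_logpdf(loc, scale):
    return lambda x: -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def model():
    # Each yield is (name, log density, observed value or None).
    mu = yield ("mu", normal_logpdf(0.0, 1.0), None)
    yield ("y", normal_logpdf(mu, 1.0), 1.5)


state = SMCExecutor().evaluate(model, {"mu": 0.5})
print(state.prior_log_prob(), state.likelihood_log_prob())
```

With beta = 1, the tempered density equals the joint density that the plain `Executor` path would use. With beta = 0, only the prior remains.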
Thanks @rrkarim. I feel like you are trying to do too much in a single PR; could you split it into:
- a standalone SMC inference API, including the changes to gathering the prior and likelihood `log_prob`
- a follow-up adding SMC as an initialization method
- progress bar support
- mixture distribution support
Sure, I will do it in a couple of hours. I'm not sure about the Mixture, because it is in progress in a different PR; I only added Mixture for the SMC notebook.
Let me add some tests for SMC and then open a separate PR. `batch_stack` is killing me right now, especially the `.sample()` method, which doesn't care about the plate variable.
Codecov Report
Merging #287 into master will decrease coverage by 0.30%. The diff coverage is 91.03%.
@@            Coverage Diff             @@
##           master     #287      +/-   ##
==========================================
- Coverage   90.89%   90.58%   -0.31%
==========================================
  Files          36       42       +6
  Lines        2822     2954     +132
==========================================
+ Hits         2565     2676     +111
- Misses        257      278      +21
| Impacted Files | Coverage Δ | |
|---|---|---|
| `pymc4/inference/utils.py` | 51.51% <51.51%> (ø) | |
| `pymc4/flow/transformed_executor.py` | 90.47% <84.61%> (-2.39%) | :arrow_down: |
| `pymc4/flow/smc_executor.py` | 85.00% <85.00%> (ø) | |
| `pymc4/inference/smc.py` | 91.22% <91.22%> (ø) | |
| `pymc4/flow/__init__.py` | 100.00% <100.00%> (ø) | |
| `pymc4/flow/executor.py` | 91.71% <100.00%> (-2.82%) | :arrow_down: |
| `pymc4/flow/meta_executor.py` | 100.00% <100.00%> (+11.90%) | :arrow_up: |
| `pymc4/flow/state/__init__.py` | 100.00% <100.00%> (ø) | |
| `pymc4/flow/state/smc_state.py` | 100.00% <100.00%> (ø) | |
| `pymc4/flow/state/state.py` | 100.00% <100.00%> (ø) | |
| ... and 6 more | | |
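The report gives a 0.30% decrease in its summary sentence but -0.31% in the diff row. Both figures follow from the hit and line counts. Coverage here is hits / lines, since no partial lines are reported. The arithmetic below assumes Codecov's default rounding, which truncates to two decimals. The row subtracts the already-truncated percentages, while the summary truncates the exact difference. The helper names are hypothetical.

```python
import math
from fractions import Fraction


def coverage(hits: int, lines: int) -> Fraction:
    # Exact coverage percentage, kept as a fraction to avoid float error.
    return Fraction(hits, lines) * 100


def round_down(pct: Fraction, places: int = 2) -> Fraction:
    # Truncate a non-negative percentage to `places` decimals.
    scale = 10 ** places
    return Fraction(math.floor(pct * scale), scale)


base = coverage(2565, 2822)  # master
head = coverage(2676, 2954)  # PR #287

print(float(round_down(base)))                     # 90.89
print(float(round_down(head)))                     # 90.58
print(float(round_down(base) - round_down(head)))  # 0.31, the diff row
print(float(round_down(base - head)))              # 0.3, the summary sentence
```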
Thanks @rrkarim, quick question: are the slow tests related to the autobatching (the use of `vmap`/`make_rank_polymorphic` in the `log_prob`)? Have you compared against a native implementation in TFP?
@junpenglao I kind of forgot about that comment; I will test it.