pymc
Warn on divergences after sampling with JAX
What is this PR about?
Closes #7041
This is a draft to add convergence checks to the JAX samplers. Right now I'm just calling `run_convergence_checks` after sampling. It might be nicer to instead wrap the JAX returns in `_sample_return`, but this was easier. Feedback requested.
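For context, the kind of post-sampling check being added can be sketched in plain Python. This is a minimal illustration, not PyMC's `run_convergence_checks`: the helper name `warn_on_divergences` is hypothetical, and it assumes the sampler's divergence flags are available as a chains × draws nested list of booleans, similar to the `diverging` field in `sample_stats`.

```python
import warnings
from typing import Sequence


def warn_on_divergences(diverging: Sequence[Sequence[bool]]) -> int:
    """Hypothetical helper: count divergent transitions and warn if any occurred.

    `diverging` is assumed to be a chains x draws nested sequence of booleans,
    mirroring the shape of a `diverging` sample stat.
    """
    # Total number of divergent draws across all chains.
    n_divergent = sum(sum(bool(flag) for flag in chain) for chain in diverging)
    if n_divergent > 0:
        # Surface the problem to the caller instead of returning silently.
        warnings.warn(
            f"{n_divergent} divergent transition(s) detected after sampling.",
            UserWarning,
            stacklevel=2,
        )
    return n_divergent
```

Calling it right after the JAX sampler returns is the "easier" option described above; the alternative of routing results through a shared return path would keep the check in one place for every sampler.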
Checklist
- [ ] Explain important implementation details 👆
- [ ] Make sure that the pre-commit linting/style checks pass.
- [ ] Link relevant issues (preferably in nice commit messages)
- [ ] Are the changes covered by tests and docstrings?
- [ ] Fill out the short summary sections 👇
Major / Breaking Changes
- The blackjax sampler should now issue a warning on divergences
New features
- None
Bugfixes
- None
Documentation
- None
Maintenance
- None
📚 Documentation preview 📚: https://pymc--7051.org.readthedocs.build/en/7051/
Codecov Report
Merging #7051 (e28c71e) into main (005ba5f) will increase coverage by 0.00%. Report is 1 commit behind head on main. The diff coverage is 100.00%.
Additional details and impacted files
Coverage Diff | main | #7051 | +/- |
---|---|---|---|
Coverage | 92.16% | 92.16% | |
Files | 101 | 101 | |
Lines | 16827 | 16831 | +4 |
Hits | 15509 | 15513 | +4 |
Misses | 1318 | 1318 | |
Files | Coverage Δ | |
---|---|---|
pymc/sampling/mcmc.py | 87.79% <100.00%> (+0.10%) | ⬆️ |
Looks fine, but why not put this in the `sample_numpyro_nuts` function itself? I ask because some users call that function directly, and we already re-implement most of the logic of `_sample_return` there?
I whipped this up as a live example during the sprint. I'll go back and do a better job, following your suggestion, @ricardoV94.
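The design suggested in the review, running the check inside the low-level sampler function so that users who call it directly also get the warning, can be sketched as below. All names here (`sample_nuts_sketch`, `sample_sketch`, `_check_divergences`) are hypothetical stand-ins, not PyMC functions, and the "sampler" simply passes through a chains × draws list of divergence flags instead of actually sampling.

```python
import warnings
from typing import Dict, List


def _check_divergences(sample_stats: Dict[str, List[List[bool]]]) -> None:
    # Hypothetical stand-in for a post-sampling convergence check.
    n_divergent = sum(sum(chain) for chain in sample_stats["diverging"])
    if n_divergent:
        warnings.warn(
            f"{n_divergent} divergent transition(s) detected after sampling.",
            UserWarning,
            stacklevel=3,
        )


def sample_nuts_sketch(diverging: List[List[bool]]) -> Dict[str, List[List[bool]]]:
    """Hypothetical low-level sampler: runs the check before returning."""
    sample_stats = {"diverging": diverging}  # stand-in for real sampler output
    _check_divergences(sample_stats)
    return sample_stats


def sample_sketch(diverging: List[List[bool]]) -> Dict[str, List[List[bool]]]:
    """Hypothetical top-level entry point: delegates without re-checking."""
    return sample_nuts_sketch(diverging)
```

Because the check lives in the low-level function, the top-level wrapper does not repeat it, so a divergent run produces exactly one warning whichever entry point is used.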