pymc
Run convergence checks when using JAX samplers
Description
Also:
- Refactored code to reduce duplication
- Inlined some very short, single-use functions
- Removed the many verbose timing log statements
- Always run divergence and treedepth checks in `run_convergence_checks`
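The divergence and treedepth checks in the last bullet can be illustrated with a short sketch. This is a hypothetical illustration, not PyMC's code: the helper names (`divergence_warnings`, `treedepth_warnings`, `run_sampler_checks`), the per-chain list-of-booleans input format, and the 5% tree-depth threshold are assumptions. It shows that both checks need only sampler statistics, so they can run for any sampler, PyMC's own NUTS or a JAX-based one, that records those statistics.

```python
# Hypothetical sketch, not PyMC's implementation: sample stats are plain
# per-chain lists of booleans instead of an arviz.InferenceData object.


def divergence_warnings(diverging):
    """Return one message if any post-tuning draw in any chain diverged."""
    n_div = sum(sum(chain) for chain in diverging)
    if n_div == 0:
        return []
    return [f"Divergent transitions after tuning: {n_div}"]


def treedepth_warnings(reached_max_treedepth, threshold=0.05):
    """Return one message per chain whose saturation rate exceeds `threshold`.

    The 5% default threshold is an assumption made for this sketch.
    """
    messages = []
    for chain_idx, chain in enumerate(reached_max_treedepth):
        if chain and sum(chain) / len(chain) > threshold:
            messages.append(f"Chain {chain_idx} reached the maximum tree depth.")
    return messages


def run_sampler_checks(diverging, reached_max_treedepth):
    """Run both cheap checks unconditionally and concatenate their messages."""
    return divergence_warnings(diverging) + treedepth_warnings(reached_max_treedepth)


# Usage: two chains of four draws each.
diverging = [[False, True, False, False], [False, False, False, False]]
max_depth = [[False, False, False, False], [True, False, True, False]]
print(run_sampler_checks(diverging, max_depth))
```

Because each check returns a list, "always run" amounts to calling both and concatenating the results, with no sampler-specific branching.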
Related Issue
- [x] Closes #7041
Supersedes the following PRs
- [x] Closes #7051
- [x] Closes #7094
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [x] Added necessary documentation (docstrings and/or example notebooks)
- [x] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [x] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7165.org.readthedocs.build/en/7165/
Codecov Report
Attention: Patch coverage is 97.75281%, with 2 lines in your changes missing coverage. Please review.
Project coverage is 92.32%. Comparing base (2e3ea56) to head (17ff331). Report is 1 commit behind head on main.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7165      +/-   ##
==========================================
+ Coverage   92.29%   92.32%   +0.02%
==========================================
  Files         101      101
  Lines       16947    16904      -43
==========================================
- Hits        15642    15607      -35
+ Misses       1305     1297       -8
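The project figures above follow from hits divided by coverable lines. A minimal sketch, assuming Codecov's default of two-decimal precision with rounding down (a repository's `codecov.yml` can change this; `coverage_pct` and `round_down` are illustrative names), reproduces the reported 92.29%, 92.32%, and +0.02%:

```python
import math


def coverage_pct(hits, lines):
    """Coverage as a percentage: covered lines over coverable lines."""
    return 100 * hits / lines


def round_down(value, precision=2):
    """Round down (floor) to `precision` decimals, the assumed Codecov default."""
    factor = 10 ** precision
    return math.floor(value * factor) / factor


base = coverage_pct(15642, 16947)  # main
head = coverage_pct(15607, 16904)  # #7165
print(round_down(base), round_down(head), round_down(head - base))
```

Taking the difference of the unrounded values is what yields +0.02%, rather than the +0.03% you would get by subtracting the displayed percentages.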
| Files | Coverage Δ | |
|---|---|---|
| pymc/sampling/mcmc.py | 87.61% <100.00%> (-0.11%) | :arrow_down: |
| pymc/stats/convergence.py | 97.43% <100.00%> (+2.65%) | :arrow_up: |
| pymc/sampling/jax.py | 94.09% <97.40%> (+0.98%) | :arrow_up: |