Warn on divergences after sampling with JAX

Open jessegrabowski opened this issue 1 year ago • 3 comments

What is this PR about? Closes #7041

This is a draft to add convergence checks to the JAX samplers.

Right now I'm just calling run_convergence_checks after sampling. It might be nice to instead wrap the JAX returns in _sample_return, but this was easier. Feedback requested.
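
As a rough illustration of what a post-sampling divergence check does, here is a minimal sketch in plain Python. It is not PyMC's run_convergence_checks. The helper name warn_on_divergences, the warning text, and the nested-list layout of the diverging flags (one list per chain) are all assumptions made for this example.

```python
import warnings


def warn_on_divergences(diverging):
    """Warn if any draw was flagged as divergent.

    `diverging` is assumed to be a list of chains, each a list of
    booleans (one flag per draw). Returns the total divergence count.
    """
    # Count flagged draws across all chains.
    n_divergent = sum(sum(bool(flag) for flag in chain) for chain in diverging)
    if n_divergent > 0:
        # Hypothetical message; the real library wording may differ.
        warnings.warn(
            f"There were {n_divergent} divergences after sampling.",
            UserWarning,
            stacklevel=2,
        )
    return n_divergent
```

Calling warn_on_divergences on the flags collected after sampling keeps the check separate from the sampler itself, which is the "call it after sampling" approach described above.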

Checklist

Major / Breaking Changes

  • The blackjax sampler should now issue a warning on divergences

New features

  • None

Bugfixes

  • None

Documentation

  • None

Maintenance

  • None

:books: Documentation preview :books:: https://pymc--7051.org.readthedocs.build/en/7051/

jessegrabowski avatar Dec 06 '23 19:12 jessegrabowski

Codecov Report

Merging #7051 (e28c71e) into main (005ba5f) will increase coverage by 0.00%. The report is 1 commit behind head on main. The diff coverage is 100.00%.

Additional details and impacted files

@@           Coverage Diff           @@
##             main    #7051   +/-   ##
=======================================
  Coverage   92.16%   92.16%           
=======================================
  Files         101      101           
  Lines       16827    16831    +4     
=======================================
+ Hits        15509    15513    +4     
  Misses       1318     1318           

Files                    Coverage Δ
pymc/sampling/mcmc.py    87.79% <100.00%> (+0.10%) :arrow_up:

codecov[bot] avatar Dec 06 '23 19:12 codecov[bot]

Looks fine, but why not do this in the sample_numpyro_nuts function itself? I'm thinking of that because some users call it directly, and we already re-implement most of the logic of _sample_return there.
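
The placement question can be sketched with stand-in functions. Everything here is hypothetical: sample_backend stands in for a backend sampler such as a JAX NUTS function, sample stands in for the top-level entry point, and the diverging argument replaces real sampling output. The sketch only shows that a check placed inside the backend function also reaches users who call it directly.

```python
import warnings


def _check_divergences(diverging):
    # Hypothetical shared check: warn when any draw is flagged divergent.
    n_divergent = sum(bool(flag) for flag in diverging)
    if n_divergent:
        warnings.warn(
            f"There were {n_divergent} divergences after sampling.",
            UserWarning,
            stacklevel=3,
        )
    return n_divergent


def sample_backend(diverging):
    # Stand-in for a backend sampler. The check runs here, so callers
    # who bypass the top-level entry point are still warned.
    _check_divergences(diverging)
    return {"diverging": list(diverging)}


def sample(diverging):
    # Stand-in for the top-level entry point. It only delegates,
    # so the warning is issued exactly once per call.
    return sample_backend(diverging)
```

If the check lived only in sample, a direct call to sample_backend would return silently even when draws diverged.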

ricardoV94 avatar Dec 07 '23 15:12 ricardoV94

I whipped this up as a live example during the sprint. I'll go back and do a better job following your suggestion, @ricardoV94.

jessegrabowski avatar Dec 07 '23 18:12 jessegrabowski