
Fix JAX sampling funcs overwriting existing var's dims and coords

Open jhrcook opened this issue 3 years ago • 1 comments

What is this PR about?

The original fix for Issue #5932 overwrites any dimensions and coordinate data already defined on the model's variables when it creates the ArviZ InferenceData object. (See the issue for a demonstration of this behavior.) This PR addresses that by using the dims and coords passed in the idata_kwargs argument to update, rather than replace, the dimensions and coordinates extracted from the model.
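The update-instead-of-overwrite behavior can be sketched with plain dictionaries. This is a minimal illustration, not the code in pymc/sampling_jax.py: the helper name `merge_user_metadata`, the example coordinate names, and the extra `log_likelihood` key are all assumptions made for the example.

```python
def merge_user_metadata(model_coords, model_dims, idata_kwargs=None):
    """Hypothetical helper: merge user-supplied dims/coords into the model's.

    Entries extracted from the model are kept; user entries from
    ``idata_kwargs`` are added on top and win only on a key collision.
    """
    # Copy so the caller's dictionaries are never mutated.
    remaining = dict(idata_kwargs or {})
    coords = dict(model_coords)
    dims = dict(model_dims)

    # Update (not replace) with whatever the user passed.
    coords.update(remaining.pop("coords", {}))
    dims.update(remaining.pop("dims", {}))

    # ``remaining`` holds the other keyword arguments, untouched.
    return coords, dims, remaining


# Illustrative values only (not taken from the PR).
model_coords = {"group": ["a", "b", "c"]}
model_dims = {"mu": ["group"]}
user_kwargs = {
    "coords": {"time": [0, 1]},
    "dims": {"y_hat": ["time"]},
    "log_likelihood": False,
}

coords, dims, remaining = merge_user_metadata(model_coords, model_dims, user_kwargs)
```

With these inputs, `coords` and `dims` contain both the model's entries and the user's. A plain overwrite would have dropped `group` and `mu`.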

Checklist

Major / Breaking Changes

  • None

Bugfixes / New features

  • User-provided dims and coords now update, rather than overwrite, those extracted from the model.

Docs / Maintenance

  • The docstring is updated to include these details.
  • Tests have been added.

jhrcook, Aug 09 '22 12:08

Codecov Report

Merging #6041 (3033125) into main (906fcdc) will increase coverage by 0.01%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #6041      +/-   ##
==========================================
+ Coverage   89.26%   89.27%   +0.01%     
==========================================
  Files          72       72              
  Lines       12890    12897       +7     
==========================================
+ Hits        11506    11514       +8     
+ Misses       1384     1383       -1     

Impacted Files                       Coverage Δ
pymc/sampling_jax.py                 97.15% <100.00%> (+0.09%)
pymc/step_methods/hmc/base_hmc.py    90.55% <0.00%> (+0.78%)

codecov[bot], Aug 09 '22 12:08

Thanks @jhrcook and @bherwerth!

twiecki, Aug 20 '22 08:08