pymc
Fix JAX sampling funcs overwriting existing var's dims and coords
What is this PR about?
The original fix for Issue #5932 overwrites any existing dimensions and coordinate data on the variables when creating the ArviZ InferenceData object. (See the issue for a demonstration of this behavior.) This PR addresses that by using the `dims` and `coords` in the `idata_kwargs` argument to update the dimensions and coordinates extracted from the model, instead of replacing them.
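A minimal sketch of the "update instead of overwrite" semantics, written in plain Python. The helper name `merge_dims_coords` and its signature are hypothetical and are not part of PyMC. The sketch only shows the dictionary behavior: entries from the model are kept, and user-supplied entries in `idata_kwargs` win when a key appears in both.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def merge_dims_coords(
    model_dims: Dict[str, List[str]],
    model_coords: Dict[str, Sequence],
    idata_kwargs: Optional[dict] = None,
) -> Tuple[Dict[str, List[str]], Dict[str, Sequence]]:
    """Hypothetical helper: update model-extracted dims/coords with user values."""
    idata_kwargs = idata_kwargs or {}
    # Copy first so the model's own mappings are not mutated.
    dims = dict(model_dims)
    coords = dict(model_coords)
    # dict.update: user entries replace shared keys, other model entries survive.
    dims.update(idata_kwargs.get("dims") or {})
    coords.update(idata_kwargs.get("coords") or {})
    return dims, coords


model_dims = {"x": ["obs"], "y": ["group"]}
model_coords = {"obs": [0, 1, 2], "group": ["a", "b"]}
dims, coords = merge_dims_coords(
    model_dims,
    model_coords,
    {"dims": {"z": ["obs"]}, "coords": {"group": ["c", "d"]}},
)
```

Here `dims` keeps the model's `x` and `y` entries and gains `z`, while `coords` keeps `obs` and takes the user's values for `group`. Simply assigning the user's dictionaries would discard the model's entries, which is the behavior this PR fixes.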
Checklist
- [x] Explain important implementation details 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Link relevant issues (preferably in nice commit messages)
- [x] Are the changes covered by tests and docstrings?
- [x] Fill out the short summary sections 👇
Major / Breaking Changes
- None
Bugfixes / New features
- The user-provided `dims` and `coords` now update, instead of overwrite, those extracted from the model.
Docs / Maintenance
- The docstring is updated to include these details.
- Tests have been added.
Codecov Report
Merging #6041 (3033125) into main (906fcdc) will increase coverage by 0.01%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff             @@
##             main    #6041      +/-   ##
==========================================
+ Coverage   89.26%   89.27%   +0.01%
==========================================
  Files          72       72
  Lines       12890    12897       +7
==========================================
+ Hits        11506    11514       +8
+ Misses       1384     1383       -1
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| pymc/sampling_jax.py | 97.15% <100.00%> (+0.09%) | :arrow_up: |
| pymc/step_methods/hmc/base_hmc.py | 90.55% <0.00%> (+0.78%) | :arrow_up: |
Thanks @jhrcook and @bherwerth!