pymc
Refactor jax internals to support dense_mass kwarg for numpyro
What is this PR about?
Enables block-dense mass matrix adaptation for numpyro via the dense_mass keyword argument.
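As a hedged illustration of what "block dense" means: the inverse mass matrix is block-diagonal. Each tuple passed to dense_mass gets a full covariance block, and every variable not listed keeps diagonal (variance-only) adaptation. This is our reading of numpyro's dense_mass semantics. The sketch below estimates such a matrix from warmup draws with plain NumPy sample covariances. numpyro's real adaptation uses windowed, regularized estimates, so this only shows the structure. `block_inverse_mass` is a hypothetical helper. The variable sizes mirror the example model further down (a_ig is 2 counties × 2 levels), and the parameter order is an assumption.

```python
import numpy as np


def block_inverse_mass(draws, sizes, dense_groups):
    """Estimate a block-structured inverse mass matrix from warmup draws.

    draws: array of shape (n_draws, n_params); columns follow the order of `sizes`.
    sizes: dict mapping variable name -> flattened size, in parameter order.
    dense_groups: list of tuples of names; each tuple becomes one dense block.
    Variables not listed in any group get diagonal (variance-only) entries.
    """
    # Column indices of each variable inside the flat parameter vector
    columns, start = {}, 0
    for name, size in sizes.items():
        columns[name] = np.arange(start, start + size)
        start += size

    inv_mass = np.zeros((start, start))
    grouped = set()
    for group in dense_groups:
        idx = np.concatenate([columns[name] for name in group])
        # Full sample covariance within the group fills one dense block
        inv_mass[np.ix_(idx, idx)] = np.atleast_2d(np.cov(draws[:, idx], rowvar=False))
        grouped.update(group)

    for name, idx in columns.items():
        if name not in grouped:
            # Ungrouped variables: sample variances on the diagonal only
            inv_mass[idx, idx] = draws[:, idx].var(axis=0, ddof=1)
    return inv_mass


# Hypothetical flattened sizes for the example model's free variables
sizes = {"a": 1, "a_g": 2, "s": 1, "s_g": 1, "a_ig": 4}
rng = np.random.default_rng(0)
draws = rng.normal(size=(500, sum(sizes.values())))
inv_mass = block_inverse_mass(draws, sizes, [("a", "a_g")])
```

Here the top-left 3×3 block (a plus both entries of a_g) is dense, and the remaining six parameters only have diagonal entries.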
Checklist
- [x] Explain important implementation details 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [ ] Link relevant issues (preferably in nice commit messages)
- [x] Are the changes covered by tests and docstrings?
- [x] Fill out the short summary sections 👇
Major / Breaking Changes
- ...
New features
- Block dense mass matrix adaptation for numpyro
- get_jaxified_logp now accepts a point_fn argument
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts

with pm.Model(
    coords=dict(level=["Basement", "Floor"], county=[1, 2]),
) as model:
    # multilevel modelling
    a = pm.Normal("a")
    s = pm.HalfNormal("s")
    a_g = pm.Normal("a_g", a, s, dims="level")
    s_g = pm.HalfNormal("s_g")
    a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
    # adapt one dense mass matrix block for ("a", "a_g"); other variables use diagonal adaptation
    trace = sample_numpyro_nuts(
        nuts_kwargs=dict(
            dense_mass=[
                ("a", "a_g"),
            ]
        )
    )
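The dense_mass groups in the example refer to variables by name. So, presumably, the log-density handed to numpyro has to be keyed by name as well. As far as we can tell, the jaxified logp otherwise takes a positional list of arrays. The sketch below is hypothetical and shows what a "point"-style logp (a dict from variable name to array) means compared with a positional one. `make_point_logp` and `toy_logp` are illustrative names, not the PR's implementation of point_fn.

```python
import numpy as np


def make_point_logp(logp_fn, value_names):
    """Hypothetical adapter from a positional logp to a dict-based ("point") logp.

    logp_fn: callable taking a list of arrays ordered like `value_names`.
    Returns a callable taking a dict {name: array}, the kind of input a
    name-based dense_mass specification can refer to.
    """

    def point_logp(point):
        # Reorder the named inputs into the positional order logp_fn expects
        return logp_fn([point[name] for name in value_names])

    return point_logp


def toy_logp(values):
    # Positional toy logp (up to a constant): a ~ N(0, 1), a_g ~ N(a, 1)
    a, a_g = values  # order matters here
    return -0.5 * a**2 - 0.5 * np.sum((a_g - a) ** 2)


point_logp = make_point_logp(toy_logp, ["a", "a_g"])
point = {"a_g": np.array([1.0, 2.0]), "a": np.array(3.0)}
print(point_logp(point))  # -7.0
```

Because the adapter looks values up by name, the dict's key order does not matter. Only `value_names` fixes the positional order.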
Bugfixes
- ...
Documentation
- ...
Maintenance
- ...
📚 Documentation preview 📚: https://pymc--7050.org.readthedocs.build/en/7050/
Codecov Report
Merging #7050 (8eb4284) into main (2e05854) will decrease coverage by 12.23%. The diff coverage is 0.00%.
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #7050       +/-   ##
===========================================
- Coverage   92.19%   79.97%   -12.23%
===========================================
  Files         101      101
  Lines       16893    16911       +18
===========================================
- Hits        15575    13524     -2051
- Misses       1318     3387     +2069
| Files | Coverage Δ | |
|---|---|---|
| pymc/sampling/jax.py | 0.00% <0.00%> (-93.08%) | ⬇️ |
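The coverage figures above can be checked by hand. Coverage is covered lines over tracked lines. The report lists no partial lines, so they are assumed to be zero here. The sketch also assumes that Codecov truncates percentages to two decimals (its default "round down" setting) rather than rounding to nearest. Under that assumption it reproduces all three numbers in the report.

```python
import math


def round_down(value, precision=2):
    # Truncate toward negative infinity at `precision` decimal places
    factor = 10**precision
    return math.floor(value * factor) / factor


def coverage_pct(hits, misses, partials=0):
    # Coverage = covered lines / all tracked lines, in percent
    return 100 * hits / (hits + misses + partials)


base = coverage_pct(hits=15575, misses=1318)  # main (2e05854)
head = coverage_pct(hits=13524, misses=3387)  # #7050 (8eb4284)
print(round_down(base), round_down(head), round_down(head - base))
# 92.19 79.97 -12.23
```

Note that the -12.23% delta comes from the unrounded values (about 92.1979% and 79.9716%). Subtracting the displayed 92.19% and 79.97% would give -12.22%.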
These failing tests are definitely caused by an issue in the latest PyTensor; I'll patch it.
Failing tests due to PyTensor should be fixed by https://github.com/pymc-devs/pytensor/pull/546
@ferrine can you rebase?
The rebase did not go smoothly, so I'm converting this to a draft.