pymc
Add `ZeroSumNormal` distribution
This PR introduces the world-famous ZeroSumNormal distribution, i.e., a Normal distribution where one or several axes are constrained to sum to zero. By default, the last axis is constrained to sum to zero.
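To make the constraint concrete, here is a minimal NumPy sketch. It assumes (this is not taken from the PR diff) that draws can be obtained by sampling iid `Normal(0, sigma)` values and subtracting the mean along each constrained axis; `zero_sum_normal_draws` is a hypothetical helper name.

```python
import numpy as np


def zero_sum_normal_draws(sigma, shape, zerosum_axes=(-1,), rng=None):
    """Hypothetical helper: draw iid Normal(0, sigma) values, then center them
    along each constrained axis so that those axes sum to zero."""
    rng = np.random.default_rng(rng)
    draws = rng.normal(0.0, sigma, size=shape)
    for axis in zerosum_axes:
        # Subtracting the mean along `axis` makes every slice along it sum to zero
        draws = draws - draws.mean(axis=axis, keepdims=True)
    return draws


# Default behaviour described above: only the last axis is constrained
x = zero_sum_normal_draws(sigma=1.0, shape=(3, 4), rng=0)
print(np.allclose(x.sum(axis=-1), 0.0))  # True
```

In this sketch, centering n iid normals leaves each element with marginal variance sigma² (1 − 1/n), so `sigma` is the scale of the underlying unconstrained Normal rather than the marginal standard deviation.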
⚠️ `sigma` has to be a scalar to ensure the zero-sum constraint. The ability to specify a vector of sigma may be added in future versions.
Checklist
- [x] Explain important implementation details
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Docstrings
- [x] Tests
Examples:

```python
import pymc as pm

COORDS = {
    "regions": ["a", "b", "c"],
    "answers": ["yes", "no", "whatever", "don't understand question"],
}

# Constrain only the "answers" axis to sum to zero
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")

# Constrain both axes to sum to zero
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))

# Axes can also be given by position: 1 is the "answers" axis here
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
```
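The examples pass `zerosum_axes` as a dim name, a tuple of dim names, or an integer position. One way to normalize those forms into axis indices is sketched below; `resolve_zerosum_axes` is a hypothetical helper for illustration, not the function used in the PR.

```python
from typing import Sequence, Tuple, Union

AxisSpec = Union[int, str, Sequence[Union[int, str]]]


def resolve_zerosum_axes(dims: Sequence[str], zerosum_axes: AxisSpec) -> Tuple[int, ...]:
    """Hypothetical helper: map dim names and/or integer positions to sorted,
    non-negative axis indices."""
    if isinstance(zerosum_axes, (int, str)):
        zerosum_axes = (zerosum_axes,)  # a single axis was given
    ndim = len(dims)
    axes = set()
    for ax in zerosum_axes:
        if isinstance(ax, str):
            if ax not in dims:
                raise ValueError(f"{ax!r} is not one of the dims {tuple(dims)}")
            axes.add(dims.index(ax))  # dim name -> position
        else:
            if not -ndim <= ax < ndim:
                raise ValueError(f"axis {ax} is out of range for {ndim} dims")
            axes.add(ax % ndim)  # negative positions -> non-negative ones
    return tuple(sorted(axes))


dims = ("regions", "answers")
print(resolve_zerosum_axes(dims, "answers"))               # (1,)
print(resolve_zerosum_axes(dims, ("regions", "answers")))  # (0, 1)
print(resolve_zerosum_axes(dims, 1))                       # (1,)
```

Normalizing to sorted, non-negative indices means that `1`, `-1`, and `"answers"` all end up constraining the same axis.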
Codecov Report
Merging #6121 (5954e65) into main (e419d53) will increase coverage by 0.35%. The diff coverage is 97.34%.
:exclamation: Current head 5954e65 differs from pull request most recent head 3e72922. Consider uploading reports for the commit 3e72922 to get more accurate results.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #6121      +/-   ##
==========================================
+ Coverage   93.05%   93.40%   +0.35%
==========================================
  Files          91      100       +9
  Lines       20804    22138    +1334
==========================================
+ Hits        19360    20679    +1319
- Misses       1444     1459      +15
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| pymc/distributions/multivariate.py | 92.24% <91.54%> (-0.07%) | :arrow_down: |
| pymc/distributions/shape_utils.py | 98.66% <97.14%> (-0.29%) | :arrow_down: |
| pymc/tests/distributions/test_multivariate.py | 99.44% <98.97%> (-0.06%) | :arrow_down: |
| pymc/distributions/timeseries.py | 82.56% <100.00%> (-0.96%) | :arrow_down: |
| pymc/distributions/transforms.py | 100.00% <100.00%> (ø) | |
| pymc/tests/distributions/test_shape_utils.py | 99.76% <100.00%> (+0.03%) | :arrow_up: |
| pymc/tests/distributions/test_timeseries.py | 95.45% <100.00%> (-0.34%) | :arrow_down: |
| pymc/parallel_sampling.py | 85.80% <0.00%> (-1.00%) | :arrow_down: |
| pymc/tests/distributions/test_truncated.py | 99.48% <0.00%> (-0.52%) | :arrow_down: |
| pymc/data.py | 80.08% <0.00%> (ø) | |
| ... and 13 more | | |
Thanks for the first review @ricardoV94! I added some tests. How do they look?
This would make more sense in `multivariate.py`, even though we are treating it as a scalar as a hack for the time being.
Good point. I moved it to multivariate.py.
I also added all the tests mentioned and pushed everything. They pass locally. Let's see if they pass here.
There is a weird error on the Ubuntu tests:

```
> self.step_size = step_scale / (size**0.25)
E ZeroDivisionError: float division by zero
```
It's not even coming from the ZSN code, but from NUTS 🤨 Does anybody know what that's about? It's most likely a platform issue: I can't reproduce it, and the tests pass on the other platforms.
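For context on that traceback: assuming `size` is the integer number of sampled dimensions, the expression can only raise `ZeroDivisionError` when `size` is 0, because `0 ** 0.25` evaluates to `0.0`. The snippet below just reproduces that arithmetic; `initial_step_size` and the `step_scale` value are hypothetical, and whether a zero-dimensional step is what actually happened on the Ubuntu runner is not established here.

```python
step_scale = 0.25  # hypothetical value; only the division matters here


def initial_step_size(step_scale: float, size: int) -> float:
    # Same expression as in the traceback: the step shrinks with the 4th root of size
    return step_scale / (size**0.25)


print(initial_step_size(step_scale, 16))
try:
    initial_step_size(step_scale, 0)  # 0 ** 0.25 == 0.0, so this divides by zero
except ZeroDivisionError as err:
    print(type(err).__name__, err)  # the same exception type as in the CI log
```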
There is also a failing test on Windows, but that's definitely unrelated to this PR:
```
> npt.assert_allclose(np.asarray(self.centers), np.sort(Xu), atol=0.1)
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=0.1
E
E Mismatched elements: 1 / 2 (50%)
E Max absolute difference: 0.10133194
E Max relative difference: 0.02068561
E x: array([-5, 5])
E y: array([-4.898668, 4.96629 ])
```
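That failure is a near miss. `np.testing.assert_allclose(actual, desired, rtol, atol)` accepts an element when `|actual - desired| <= atol + rtol * |desired|`. Redoing that check with the values printed in the log (`x` is `actual`, `y` is `desired`, rounded as printed) shows that only the first element exceeds the 0.1 tolerance, by roughly 0.0013.

```python
import numpy as np

# Values copied from the failing assertion above (y is rounded as printed)
actual = np.array([-5.0, 5.0])            # x: np.asarray(self.centers)
desired = np.array([-4.898668, 4.96629])  # y: np.sort(Xu)
rtol, atol = 1e-7, 0.1

abs_diff = np.abs(actual - desired)
allowed = atol + rtol * np.abs(desired)  # assert_allclose's per-element tolerance
mismatched = abs_diff > allowed

print(abs_diff)    # approximately [0.101332 0.03371]
print(mismatched)  # [ True False] -> 1 / 2 mismatched, as in the log
print((abs_diff / np.abs(desired)).max())  # approximately 0.0207, the "max relative difference"
```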
@lucianopaz @ricardoV94 you can re-review. I'll try to add the test against MvNormal logp this weekend, but other than that this PR should be complete.
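One fact relevant to a comparison against MvNormal: if zero-sum draws are built by centering n iid `Normal(0, sigma)` values along an axis (an assumed construction, not taken from the PR), their covariance is sigma² (I − 11ᵀ/n), which is singular with rank n − 1. That matrix therefore cannot go straight into a density that needs a Cholesky factor; a logp comparison would likely have to work on the (n − 1)-dimensional constrained subspace. The snippet is only a Monte Carlo sanity check of that covariance, not the logp test itself.

```python
import numpy as np

n, sigma = 4, 1.5
rng = np.random.default_rng(42)

# Center iid Normal(0, sigma) draws along the last axis (assumed construction)
raw = rng.normal(0.0, sigma, size=(200_000, n))
zs = raw - raw.mean(axis=-1, keepdims=True)

# Theoretical covariance of the centered vector: sigma^2 * (I - 1 1^T / n)
expected_cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
empirical_cov = np.cov(zs, rowvar=False)

print(np.abs(empirical_cov - expected_cov).max())  # small Monte Carlo error
print(np.linalg.matrix_rank(expected_cov))         # 3, i.e. n - 1: the covariance is singular
```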