pymc

Add `ZeroSumNormal` distribution 🔥

Open · AlexAndorra opened this issue 3 years ago · 4 comments

This PR introduces the world-famous ZeroSumNormal distribution, i.e., a Normal distribution where one or several axes are constrained to sum to zero. By default, the last axis is constrained to sum to zero.
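A minimal NumPy sketch of the idea, not the PR's implementation: draw iid `Normal(0, sigma)` values and subtract the mean along each constrained axis. The helper `sample_zero_sum_normal` and its tuple-of-ints `zerosum_axes` argument are hypothetical, chosen only to mirror the default of constraining the last axis.

```python
import numpy as np


def sample_zero_sum_normal(sigma, shape, zerosum_axes=(-1,), rng=None):
    """Draw from a zero-sum normal by centering iid Normal(0, sigma) draws.

    Hypothetical helper for illustration only; this is not PyMC's implementation.
    `sigma` must be a scalar; `zerosum_axes` defaults to the last axis.
    """
    rng = np.random.default_rng(rng)
    draws = rng.normal(0.0, sigma, size=shape)
    # Subtracting the mean along a constrained axis forces that axis to sum to zero.
    for axis in zerosum_axes:
        draws = draws - draws.mean(axis=axis, keepdims=True)
    return draws


# Shape (3, 4) plays the role of dims=("regions", "answers"); constrain both axes.
v = sample_zero_sum_normal(1.0, (3, 4), zerosum_axes=(0, 1), rng=1)
print(np.allclose(v.sum(axis=0), 0.0), np.allclose(v.sum(axis=1), 0.0))  # True True
```

Centering an axis of length n this way leaves each element with marginal variance `sigma**2 * (n - 1) / n`, so `sigma` is the scale of the underlying unconstrained normal rather than the marginal standard deviation. Centering a second axis does not undo the first: once every column sums to zero, the row means also sum to zero, so subtracting them keeps the column sums at zero.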

⚠️ `sigma` has to be a scalar to ensure the zero-sum constraint. The ability to specify a vector of `sigma` may be added in future versions.
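One way to read the scalar restriction, sketched in NumPy under the assumption that `sigma` would scale an already-centered draw elementwise: a common scale factor keeps the sum at zero, while per-element scales generally do not. The variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# A zero-sum vector: centered iid standard normal draws.
z = rng.normal(size=4)
z = z - z.mean()

scalar_sigma = 2.0
vector_sigma = np.array([1.0, 2.0, 3.0, 4.0])

# A common scale multiplies the (zero) sum, so the constraint survives.
print(np.isclose((scalar_sigma * z).sum(), 0.0))  # True
# Per-element scales reweight the entries, so the sum is almost surely nonzero.
print(np.isclose((vector_sigma * z).sum(), 0.0))
```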

Examples:

```python
import pymc as pm

COORDS = {
    "regions": ["a", "b", "c"],
    "answers": ["yes", "no", "whatever", "don't understand question"],
}

# Constrain only the "answers" axis to sum to zero
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")

# Constrain both axes to sum to zero
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))

# Axes can also be given by position: axis 1 is "answers" here
with pm.Model(coords=COORDS) as m:
    v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
```

AlexAndorra, Sep 12 '22 14:09

Codecov Report

Merging #6121 (5954e65) into main (e419d53) will increase coverage by 0.35%. The diff coverage is 97.34%.

❗ Current head 5954e65 differs from pull request most recent head 3e72922. Consider uploading reports for the commit 3e72922 to get more accurate results.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #6121      +/-   ##
==========================================
+ Coverage   93.05%   93.40%   +0.35%     
==========================================
  Files          91      100       +9     
  Lines       20804    22138    +1334     
==========================================
+ Hits        19360    20679    +1319     
- Misses       1444     1459      +15     
```
| Impacted Files | Coverage Δ |
|---|---|
| `pymc/distributions/multivariate.py` | `92.24% <91.54%> (-0.07%)` ⬇️ |
| `pymc/distributions/shape_utils.py` | `98.66% <97.14%> (-0.29%)` ⬇️ |
| `pymc/tests/distributions/test_multivariate.py` | `99.44% <98.97%> (-0.06%)` ⬇️ |
| `pymc/distributions/timeseries.py` | `82.56% <100.00%> (-0.96%)` ⬇️ |
| `pymc/distributions/transforms.py` | `100.00% <100.00%> (ø)` |
| `pymc/tests/distributions/test_shape_utils.py` | `99.76% <100.00%> (+0.03%)` ⬆️ |
| `pymc/tests/distributions/test_timeseries.py` | `95.45% <100.00%> (-0.34%)` ⬇️ |
| `pymc/parallel_sampling.py` | `85.80% <0.00%> (-1.00%)` ⬇️ |
| `pymc/tests/distributions/test_truncated.py` | `99.48% <0.00%> (-0.52%)` ⬇️ |
| `pymc/data.py` | `80.08% <0.00%> (ø)` |

... and 13 more

codecov[bot], Sep 12 '22 15:09

Thanks for the first review, @ricardoV94! I added some tests. How do they look?

AlexAndorra, Sep 12 '22 15:09

> This would make more sense in multivariate even though we are treating it as a scalar as a hack for the time being.

Good point. I moved it to `multivariate.py`.

I also added all the tests mentioned and pushed everything. They pass locally. Let's see if they pass here 🤞

AlexAndorra, Sep 15 '22 13:09

There is a weird error on the Ubuntu tests:

```
>       self.step_size = step_scale / (size**0.25)
E       ZeroDivisionError: float division by zero
```

It's not even coming from the ZSN code, but from NUTS 🤨 Does anybody know what that's about? It's most probably a platform issue: I can't replicate it, and the tests pass on the other platforms.
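The failing line divides by `size**0.25`, so it can only raise `ZeroDivisionError` when `size` is 0 (since `0**0.25 == 0.0`). Assuming `size` is the number of dimensions NUTS is sampling, which the log does not confirm, this would point at a test model that ends up with no free dimensions on that runner. The hypothetical `hmc_step_size` helper below only reproduces the arithmetic of that line:

```python
def hmc_step_size(step_scale, size):
    """Reproduce the failing expression: step_scale / (size ** 0.25).

    Hypothetical helper for illustration; `size` is assumed to be the
    dimensionality of the sampled parameter vector.
    """
    return step_scale / (size**0.25)


# Nonzero size: 0.25 / 16**0.25 = 0.25 / 2 = 0.125
print(hmc_step_size(0.25, 16))

# size == 0 makes the denominator 0.0, which raises the error seen in CI
try:
    hmc_step_size(0.25, 0)
except ZeroDivisionError as err:
    print(err)  # float division by zero
```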

There is also a failing test on Windows, but that's definitely unrelated to this PR:

```
>       npt.assert_allclose(np.asarray(self.centers), np.sort(Xu), atol=0.1)
E       AssertionError: 
E       Not equal to tolerance rtol=1e-07, atol=0.1
E       
E       Mismatched elements: 1 / 2 (50%)
E       Max absolute difference: 0.10133194
E       Max relative difference: 0.02068561
E        x: array([-5,  5])
E        y: array([-4.898668,  4.96629 ])
```
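The Windows failure is a narrow tolerance miss. `numpy.testing.assert_allclose(actual, desired, ...)` passes an element when `|actual - desired| <= atol + rtol * |desired|`, and the snippet below replays that check on the values printed in the log (NumPy rounded them for display, so the differences are approximate):

```python
import numpy as np
import numpy.testing as npt

# Values copied from the printed failure output (x = actual, y = desired)
actual = np.array([-5, 5])
desired = np.array([-4.898668, 4.96629])
atol, rtol = 0.1, 1e-7

# Elementwise tolerance check used by assert_allclose
within_tol = np.abs(actual - desired) <= atol + rtol * np.abs(desired)
print(within_tol)  # [False  True]: only the first center misses, by about 0.0013

# Relative differences; the largest is about 0.0207, as reported in the log
rel_diff = np.abs(actual - desired) / np.abs(desired)

try:
    npt.assert_allclose(actual, desired, atol=atol)
except AssertionError:
    print("assert_allclose fails with atol=0.1")
```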

AlexAndorra, Sep 18 '22 14:09

@lucianopaz @ricardoV94, you can re-review. I'll try to add the test against the MvNormal logp this weekend, but other than that, this PR should be complete 🤞
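A sketch of one way such a comparison could be framed, in plain NumPy rather than PyMC. It assumes the comparison target is a degenerate MvNormal with covariance `sigma**2 * (I - 11^T / n)`, which is the covariance of `n` centered iid `Normal(0, sigma)` draws. Because that covariance is singular, the density is taken on the zero-sum subspace via the pseudo-determinant and pseudo-inverse. All helper names are hypothetical, and whether the PR's test uses this exact setup is an assumption.

```python
import numpy as np


def zero_sum_cov(n, sigma):
    """Covariance of x - mean(x) for x ~ Normal(0, sigma**2 * I_n): sigma**2 * (I - 11^T / n)."""
    return sigma**2 * (np.eye(n) - np.ones((n, n)) / n)


def degenerate_mvnormal_logp(x, cov, tol=1e-10):
    """Log-density of a singular MvNormal on its support (pseudo-determinant / pseudo-inverse)."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    keep = eigvals > tol  # drop the zero eigenvalue along the all-ones direction
    coords = eigvecs[:, keep].T @ x  # coordinates of x within the support
    quad = np.sum(coords**2 / eigvals[keep])
    return -0.5 * (keep.sum() * np.log(2 * np.pi) + np.sum(np.log(eigvals[keep])) + quad)


def zero_sum_logp(x, sigma):
    """Closed form on the zero-sum subspace: n - 1 free directions, each with variance sigma**2."""
    n = x.size
    return -0.5 * (n - 1) * np.log(2 * np.pi * sigma**2) - (x @ x) / (2 * sigma**2)


x = np.array([1.0, -2.0, 0.5, 0.5])  # sums to zero
sigma = 1.5
print(np.isclose(zero_sum_logp(x, sigma), degenerate_mvnormal_logp(x, zero_sum_cov(x.size, sigma))))  # True
```

The closed form works because every nonzero eigenvalue of the covariance equals `sigma**2`, so the pseudo-determinant is `sigma**(2 * (n - 1))` and the quadratic form reduces to `x @ x / sigma**2` for zero-sum `x`.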

AlexAndorra, Sep 30 '22 14:09