Allow for batched `alpha` in `StickBreakingWeights`
What is this PR about?
Addressing #5383
This enables `StickBreakingWeights`'s `alpha` to accept batched data (>2D), makes `infer_shape` work with batched data, and fixes `rng_fn` by broadcasting `alpha` to `K`.
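As a rough illustration of the broadcasting involved, here is a numpy sketch of a stick-breaking draw with batched `alpha`. This is not PyMC's actual `rng_fn`; the function name and signature are assumptions for illustration only:

```python
# Hypothetical sketch (NOT PyMC's implementation): drawing stick-breaking
# weights for a batched `alpha` by broadcasting alpha against the K sticks.
import numpy as np

def stick_breaking_rng(alpha, K, rng=None):
    """Draw weights of shape ``alpha.shape + (K + 1,)``."""
    rng = np.random.default_rng() if rng is None else rng
    alpha = np.asarray(alpha, dtype=float)
    # Broadcast alpha along a new trailing K dimension of stick proportions.
    betas = rng.beta(1.0, alpha[..., None], size=alpha.shape + (K,))
    # Cumulative leftover stick lengths: prod_{j<=k} (1 - v_j).
    sticks = np.cumprod(1.0 - betas, axis=-1)
    # w_k = v_k * prod_{j<k} (1 - v_j)
    weights = betas * np.concatenate(
        [np.ones(alpha.shape + (1,)), sticks[..., :-1]], axis=-1
    )
    # The remaining mass goes to the last component so weights sum to 1.
    return np.concatenate([weights, sticks[..., -1:]], axis=-1)

w = stick_breaking_rng(np.array([[1.0, 2.0], [3.0, 4.0]]), K=5)
print(w.shape)  # (2, 2, 6): batch dims of alpha, plus K + 1 weights
```

The key point is that `alpha[..., None]` lets a batched `alpha` of any shape broadcast against the `K` stick dimension, which is the shape behavior this PR adds.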
Checklist
- [x] Explain important implementation details 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Link relevant issues (preferably in nice commit messages)
- [x] Are the changes covered by tests and docstrings?
- [x] Fill out the short summary sections 👇
Major / Breaking Changes
- None
Bugfixes / New features
- `alpha` now supports >2D data
Docs / Maintenance
- The docstring is updated to include these details.
- Tests have been added.
Codecov Report
Merging #6042 (2a5df64) into main (ad16bf4) will decrease coverage by 1.82%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%   -1.83%
==========================================
  Files          72       72
  Lines       12890    12946      +56
==========================================
- Hits        11507    11321     -186
- Misses      1383     1625      +242
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| pymc/distributions/multivariate.py | 91.25% <100.00%> (-0.75%) | :arrow_down: |
| pymc/distributions/timeseries.py | 43.36% <0.00%> (-35.28%) | :arrow_down: |
| pymc/model_graph.py | 65.66% <0.00%> (-29.80%) | :arrow_down: |
| pymc/model.py | 76.14% <0.00%> (-12.06%) | :arrow_down: |
| pymc/step_methods/hmc/quadpotential.py | 73.76% <0.00%> (-6.94%) | :arrow_down: |
| pymc/util.py | 75.29% <0.00%> (-2.36%) | :arrow_down: |
| pymc/distributions/discrete.py | 97.65% <0.00%> (-1.57%) | :arrow_down: |
| pymc/step_methods/hmc/base_hmc.py | 89.76% <0.00%> (-0.79%) | :arrow_down: |
| pymc/gp/gp.py | 92.73% <0.00%> (-0.45%) | :arrow_down: |
| ... and 9 more | | |
Would the `moment` and `logp` methods in the distribution class need appropriate broadcasting for vector-valued `alpha`s? I am just reading over this PR and might have missed previous discussions.
We have tests for batched alpha, but not moment (we should)
I need some assistance calculating `expected` in `test_stickbreakingweights_moment`.
Left a comment above about the fixture. Also, don't forget @larryshamalama's remark above that we should test that the `moment` function works for batched alpha as well. The existing tests are here: https://github.com/pymc-devs/pymc/blob/7af102d40f7b64184cd0fa013ced8570772fc8eb/pymc/tests/test_distributions_moments.py#L1171-L1174
Should be enough to test with a vector of two alphas, maybe one of those that is already tested for single alpha (reusing the same k) and the other being an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.
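On the suggested extreme value: with `alpha = 1`, each stick proportion is `Beta(1, 1) = Uniform(0, 1)`, so the expected weights form a simple geometric series, `E[w_k] = (1/2)**(k + 1)`, with the last component holding the leftover `(1/2)**K`. A quick numpy Monte Carlo check of this claim (a sketch, not PyMC test code):

```python
# Monte Carlo check that alpha = 1 yields the geometric-series moment.
import numpy as np

rng = np.random.default_rng(42)
N, K = 200_000, 5
v = rng.uniform(size=(N, K))  # Beta(1, 1) stick proportions
# w_k = v_k * prod_{j<k} (1 - v_j)
leftover = np.cumprod(np.concatenate([np.ones((N, 1)), 1 - v[:, :-1]], axis=1), axis=1)
w = v * leftover
# Last component takes the remaining mass so each row sums to 1.
w = np.concatenate([w, 1 - w.sum(axis=1, keepdims=True)], axis=1)

expected = np.append(0.5 ** np.arange(1, K + 1), 0.5**K)
print(np.allclose(w.mean(axis=0), expected, atol=1e-2))  # True
```

A closed-form moment like this is handy for the extreme-value case because it needs no numerical integration in the test.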
Yes, I got the test for `moment`, but I am not sure how the `expected` is calculated here. Is there any equation to determine the `expected`? https://github.com/pymc-devs/pymc/blob/7af102d40f7b64184cd0fa013ced8570772fc8eb/pymc/tests/test_distributions_moments.py#L1147-L1151
You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values.
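One way to compute `expected` is from the textbook stick-breaking mean: with `v_k ~ Beta(1, alpha)` independent, `E[v_k] = 1/(1 + alpha)`, so `E[w_k] = (alpha/(1 + alpha))**k / (1 + alpha)`, and the last of the `K + 1` components carries the leftover mass `(alpha/(1 + alpha))**K`. A numpy sketch of that closed form, showing that the batched result is just the stacked per-alpha results (this may or may not match PyMC's `moment` implementation exactly):

```python
# Sketch: closed-form mean of stick-breaking weights with v_k ~ Beta(1, alpha).
import numpy as np

def expected_weights(alpha, K):
    """E[w] with shape ``alpha.shape + (K + 1,)`` (illustrative helper)."""
    alpha = np.asarray(alpha, dtype=float)[..., None]
    r = alpha / (1.0 + alpha)           # E[1 - v_k]
    body = r ** np.arange(K) / (1.0 + alpha)
    # Last component holds the leftover mass r**K, so the mean sums to 1.
    return np.concatenate([body, r**K], axis=-1)

# The batched moment just stacks the per-alpha moments:
single_1 = expected_weights(1.0, K=3)
single_5 = expected_weights(5.0, K=3)
batched = expected_weights(np.array([1.0, 5.0]), K=3)
print(np.allclose(batched, np.stack([single_1, single_5])))  # True
```

So a test for batched alpha can reuse the already-verified single-alpha expectations and simply stack them.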
Ok, got it now. Do I need to create a separate test for batched alpha, as we did in `TestStickBreakingWeights_1D_alpha`?
Nope, you can just add it as an extra condition in the existing tests. Moments are less sensitive than logp, so we can keep it bundled together.
Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well.