Allow for batched `alpha` in `StickBreakingWeights`

Open • purna135 opened this issue 2 years ago • 1 comment

What is this PR about?

Addresses #5383. This PR enables the `alpha` parameter of `StickBreakingWeights` to accept batched data (>2D), makes `infer_shape` work with batched data, and fixes `rng_fn` by broadcasting `alpha` to `K`.
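The idea of "broadcasting `alpha` to `K`" in the random generator can be sketched in plain NumPy. This is a minimal illustration, not the code in this PR: the helper `stick_breaking_weights_rng` is hypothetical, and it assumes the distribution returns K + 1 weights (K broken-off pieces plus the leftover stick), with each break fraction drawn from Beta(1, alpha).

```python
import numpy as np


def stick_breaking_weights_rng(rng, alpha, K, size=None):
    """Hypothetical sketch: K + 1 stick-breaking weights for a (possibly batched) alpha.

    `size` is the batch shape as a tuple; it defaults to alpha's own shape.
    """
    alpha = np.asarray(alpha, dtype=float)
    batch_shape = alpha.shape if size is None else tuple(size)
    # Add a trailing axis to alpha and broadcast it to (*batch_shape, K)
    alpha_k = np.broadcast_to(alpha[..., np.newaxis], batch_shape + (K,))
    # Fraction of the remaining stick broken off at each step: Beta(1, alpha)
    betas = rng.beta(1.0, alpha_k)
    # Remaining stick after each break: (1 - b0), (1 - b0)(1 - b1), ...
    remaining = np.cumprod(1.0 - betas, axis=-1)
    # Stick length before each break: 1, (1 - b0), (1 - b0)(1 - b1), ...
    before = np.concatenate([np.ones(batch_shape + (1,)), remaining[..., :-1]], axis=-1)
    weights = betas * before
    # Append the leftover piece so that the K + 1 weights sum to one
    return np.concatenate([weights, remaining[..., -1:]], axis=-1)


rng = np.random.default_rng(42)
w = stick_breaking_weights_rng(rng, alpha=np.array([0.5, 1.0, 3.0]), K=5)
print(w.shape)  # (3, 6)
print(np.allclose(w.sum(axis=-1), 1.0))  # True
```

Each row of `w` is an independent draw for the corresponding entry of `alpha`; the trailing axis added to `alpha` is what lets one concentration value apply to all K breaks.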

Checklist

Major / Breaking Changes

  • None

Bugfixes / New features

  • `alpha` now supports > 2D data

Docs / Maintenance

  • The docstring is updated to include these details.
  • Tests have been added.

purna135, Aug 09 '22 19:08

Codecov Report

Merging #6042 (2a5df64) into main (ad16bf4) will decrease coverage by 1.82%. The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##             main    #6042      +/-   ##
==========================================
- Coverage   89.27%   87.44%   -1.83%     
==========================================
  Files          72       72              
  Lines       12890    12946      +56     
==========================================
- Hits        11507    11321     -186     
- Misses       1383     1625     +242     
```

| Impacted Files | Coverage Δ |
|---|---|
| pymc/distributions/multivariate.py | 91.25% <100.00%> (-0.75%) ↓ |
| pymc/distributions/timeseries.py | 43.36% <0.00%> (-35.28%) ↓ |
| pymc/model_graph.py | 65.66% <0.00%> (-29.80%) ↓ |
| pymc/model.py | 76.14% <0.00%> (-12.06%) ↓ |
| pymc/step_methods/hmc/quadpotential.py | 73.76% <0.00%> (-6.94%) ↓ |
| pymc/util.py | 75.29% <0.00%> (-2.36%) ↓ |
| pymc/distributions/discrete.py | 97.65% <0.00%> (-1.57%) ↓ |
| pymc/step_methods/hmc/base_hmc.py | 89.76% <0.00%> (-0.79%) ↓ |
| pymc/gp/gp.py | 92.73% <0.00%> (-0.45%) ↓ |
| ... and 9 more | |

codecov[bot], Aug 09 '22 20:08

Would the `moment` and `logp` methods in the distribution class need appropriate broadcasting for vector-valued alphas? I am just reading over this PR and might have missed previous discussions.
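To make the broadcasting concern concrete, here is a minimal NumPy sketch of the mean weights, not PyMC's `moment` implementation. It assumes i.i.d. break fractions Beta(1, alpha) and K + 1 weights with the last one being the leftover stick; `stick_breaking_mean` is a hypothetical helper. A vector of alphas only combines with `arange(K)` once it gets a trailing axis.

```python
import numpy as np


def stick_breaking_mean(alpha, K):
    """Mean of the K + 1 stick-breaking weights (hypothetical helper).

    Assumes i.i.d. break fractions beta_k ~ Beta(1, alpha), so that
    E[beta_k] = 1 / (1 + alpha) and E[1 - beta_k] = alpha / (1 + alpha).
    """
    # New trailing axis: alpha of shape (B,) becomes (B, 1) and broadcasts against arange(K)
    alpha = np.asarray(alpha, dtype=float)[..., np.newaxis]
    keep = alpha / (1.0 + alpha)
    pieces = keep ** np.arange(K) / (1.0 + alpha)  # shape (..., K)
    leftover = keep ** K                            # shape (..., 1)
    return np.concatenate([pieces, leftover], axis=-1)


alpha = np.array([1.0, 3.0])
# Without the trailing axis, shapes (2,) and (5,) cannot broadcast
try:
    (alpha / (1.0 + alpha)) ** np.arange(5)
except ValueError as err:
    print("broadcast error:", err)

batched = stick_breaking_mean(alpha, K=5)
print(batched.shape)  # (2, 6)
# Each batched row matches the corresponding scalar-alpha result
print(np.allclose(batched[1], stick_breaking_mean(3.0, K=5)))  # True
```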

larryshamalama, Aug 16 '22 17:08

We have tests for batched alpha, but not for `moment` (we should).

ricardoV94, Aug 16 '22 19:08

I need some help calculating `expected` in `test_stickbreakingweights_moment`.

purna135, Aug 18 '22 10:08

Left a comment above about the fixture. Also, don't forget @larryshamalama's remark above that we should test that the moment function works for batched alpha as well. The existing tests are here: https://github.com/pymc-devs/pymc/blob/7af102d40f7b64184cd0fa013ced8570772fc8eb/pymc/tests/test_distributions_moments.py#L1171-L1174

It should be enough to test with a vector of two alphas: one that is already tested as a single alpha (reusing the same k), and another at an extreme value like alpha=1 or alpha=0 (if that's valid), which might have a very simple moment.

ricardoV94, Aug 25 '22 09:08

Yes, I've got the moment test, but I am not sure how `expected` is calculated here. Is there an equation to determine `expected`?

https://github.com/pymc-devs/pymc/blob/7af102d40f7b64184cd0fa013ced8570772fc8eb/pymc/tests/test_distributions_moments.py#L1147-L1151

purna135, Aug 25 '22 10:08

You can check what the moment is for two distinct single alphas, and it should be the same for a batched alpha that has those two values.
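For reference, a closed form is available if one assumes the moment is the mean of the weights, the break fractions are i.i.d. $\beta_k \sim \mathrm{Beta}(1, \alpha)$, and the last of the K + 1 weights is the leftover stick. Independence lets the expectation factor through the product:

$$
w_k = \beta_k \prod_{j=0}^{k-1} (1 - \beta_j), \quad k = 0, \dots, K-1, \qquad w_K = \prod_{j=0}^{K-1} (1 - \beta_j)
$$

$$
\mathbb{E}[w_k] = \frac{1}{1+\alpha} \left( \frac{\alpha}{1+\alpha} \right)^{k}, \qquad \mathbb{E}[w_K] = \left( \frac{\alpha}{1+\alpha} \right)^{K}
$$

With $\alpha = 1$ this reduces to $\mathbb{E}[w_k] = 2^{-(k+1)}$ and $\mathbb{E}[w_K] = 2^{-K}$, and a batched alpha simply stacks these per-alpha vectors.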

ricardoV94, Aug 25 '22 10:08

OK, got it. Do I need to create a separate test for batched alpha, as we did in `TestStickBreakingWeights_1D_alpha`?

purna135, Aug 25 '22 10:08

Nope, you can just add it as an extra condition in the existing tests. The moment is less sensitive than `logp`, so we can keep it bundled together.
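A sketch of building `expected` for such an extra condition by stacking single-alpha results, assuming the existing moment test is parametrized over `(alpha, K, size, expected)` tuples (an assumption about its layout). `single_alpha_expected` is a hypothetical helper, and 3.0 stands in for whichever alpha an existing single-alpha case already covers.

```python
import numpy as np


def single_alpha_expected(alpha, K):
    """Expected mean weights for one scalar alpha (hypothetical test helper)."""
    keep = alpha / (1 + alpha)
    return np.append(keep ** np.arange(K) / (1 + alpha), keep ** K)


K = 11
# 3.0 is a placeholder for an alpha already covered by a single-alpha case;
# 1.0 is the "extreme" value whose moment is 1/2, 1/4, ..., 1/2**K, 1/2**K
batched_alpha = np.array([3.0, 1.0])
batched_expected = np.stack([single_alpha_expected(a, K) for a in batched_alpha])

# Hypothetical extra entry for an (alpha, K, size, expected) parametrize list
extra_case = (batched_alpha, K, None, batched_expected)
print(batched_expected.shape)  # (2, 12)
```

Building the batched expectation row by row from scalar results mirrors the advice above: the batched moment should equal the single-alpha moments placed side by side.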

ricardoV94, Aug 25 '22 11:08

Looks complete to me. The failing test is unrelated. Just asked if @larryshamalama could leave a review as well.

ricardoV94, Aug 29 '22 10:08