
Incorporating AePPL into `GaussianRandomWalk`

Open larryshamalama opened this issue 2 years ago • 21 comments

Closes #5762.

This PR removes `GaussianRandomWalkRV` altogether by defining the `rv_op` as a cumulative sum of Normal random variables. The most recent version of AePPL, 0.0.31, is able to derive the appropriate logp graph from a `CumOp`.

This is a WIP for now as I figure out how to use AePPL...
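Roughly, the construction looks like this (a simplified sketch; the actual implementation also handles the `mu`, `sigma`, `init_dist`, and shape arguments):

```python
import aesara.tensor as at
import pymc as pm
from aesara.tensor.extra_ops import cumsum

# Sketch: the walk is the cumulative sum of an initial value followed
# by i.i.d. Normal increments; no dedicated RandomVariable Op needed.
init = pm.Normal.dist(0.0, 100.0)                         # x_0
innovations = pm.Normal.dist(mu=0.0, sigma=1.0, size=10)  # increments
grw = cumsum(at.concatenate([init[None], innovations]))
# AePPL 0.0.31 can derive the logp graph of `grw` through the CumOp.
```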

larryshamalama avatar May 28 '22 02:05 larryshamalama

Codecov Report

Merging #5814 (9a7badd) into main (906fcdc) will decrease coverage by 6.24%. The diff coverage is 93.18%.

:exclamation: Current head 9a7badd differs from pull request most recent head dbeb801. Consider uploading reports for the commit dbeb801 to get more accurate results

Impacted file tree graph

```diff
@@            Coverage Diff             @@
##             main    #5814      +/-   ##
==========================================
- Coverage   89.26%   83.01%   -6.25%     
==========================================
  Files          72       73       +1     
  Lines       12890    13261     +371     
==========================================
- Hits        11506    11009     -497     
- Misses       1384     2252     +868     
```
| Impacted Files | Coverage Δ |
| --- | --- |
| pymc/distributions/timeseries.py | `77.55% <92.85%> (-1.09%)` :arrow_down: |
| pymc/distributions/logprob.py | `89.55% <100.00%> (-8.18%)` :arrow_down: |
| pymc/func_utils.py | `22.50% <0.00%> (-65.88%)` :arrow_down: |
| pymc/distributions/bound.py | `45.54% <0.00%> (-54.46%)` :arrow_down: |
| pymc/distributions/mixture.py | `59.35% <0.00%> (-36.37%)` :arrow_down: |
| pymc/exceptions.py | `68.96% <0.00%> (-31.04%)` :arrow_down: |
| pymc/distributions/discrete.py | `71.35% <0.00%> (-27.87%)` :arrow_down: |
| pymc/tuning/scaling.py | `35.18% <0.00%> (-24.08%)` :arrow_down: |
| pymc/math.py | `47.34% <0.00%> (-22.71%)` :arrow_down: |
| pymc/distributions/continuous.py | `76.31% <0.00%> (-21.17%)` :arrow_down: |
| ... and 22 more | |

codecov[bot] avatar May 28 '22 02:05 codecov[bot]

Nice! However, this isn't the same thing; it's a different parameterization of the same thing, like centered vs. non-centered. So I don't think we should remove the previous one, as people might be using it, and changing the parameterization might screw with the inference.

However, we could add it either as an option or a separate dist.

twiecki avatar May 28 '22 09:05 twiecki

> Nice! However, this isn't the same thing; it's a different parameterization of the same thing, like centered vs. non-centered. So I don't think we should remove the previous one, as people might be using it, and changing the parameterization might screw with the inference.
>
> However, we could add it either as an option or a separate dist.

What do you mean? This should be completely equivalent to what we had before.

ricardoV94 avatar May 28 '22 09:05 ricardoV94

You might be thinking of the latent parametrization, but that was not what GRW did before (as that wouldn't have been a vanilla distribution).

Something like that would just require implementing a default (diff) transform, which would definitely be helpful for random walks.

Alternatively, it would be useful to allow different implementations for unobserved and observed variables: either eagerly at creation time or, even better, lazily when compiling the logp, though the latter requires some more work (making sure we have the right deterministics).
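Roughly, such a diff transform could look like this (just a sketch against AePPL's RVTransform interface; names and details are illustrative):

```python
import aesara.tensor as at
from aesara.tensor.extra_ops import cumsum, diff
from aeppl.transforms import RVTransform

class DiffTransform(RVTransform):
    name = "diff"

    def forward(self, value, *inputs):
        # Walk -> increments; keep the first point so the map is bijective
        return at.concatenate([value[..., :1], diff(value, axis=-1)], axis=-1)

    def backward(self, value, *inputs):
        # Increments -> walk
        return cumsum(value, axis=-1)

    def log_jac_det(self, value, *inputs):
        # The Jacobian of cumsum is unit lower-triangular, so |det| = 1
        return at.zeros_like(value[..., 0])
```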

ricardoV94 avatar May 28 '22 09:05 ricardoV94

Isn't this changing the previous parameterization $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$ to $x_t \sim x_{t-1} + \mathcal{N}(0, \sigma)$?

twiecki avatar May 31 '22 07:05 twiecki

> Isn't this changing the previous parameterization $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$ to $x_t \sim x_{t-1} + \mathcal{N}(0, \sigma)$?

Where would the parametrization be changed?

The random generator method is exactly the same as before if you check what was going on in the old rng_fn. The logp is exactly the same as before (otherwise this test would fail): https://github.com/pymc-devs/pymc/blob/5703a9d24ca01da8b81363874e023e59cd18688d/pymc/tests/test_distributions.py#L2609-L2624

The expression $x_{t-1} + \mathcal{N}(0, \sigma)$ "does not have a logp"; it always has to be converted to $\mathcal{N}(x_{t-1}, \sigma)$, which AePPL does when it sees it or, in this case, when it sees $x \sim \text{cumsum}(\mathcal{N}(0, \sigma))$.

Now for NUTS sampling (when we have an unobserved GRW) it would probably be better if we were sampling in the latent space $x_{\text{raw}} \sim \mathcal{N}(0, \sigma)$ and only then did the deterministic mapping to $x = \text{cumsum}(x_{\text{raw}})$, but this was not done before, nor is it now.

For that we would need to either 1) add a new type of transform, or 2) have a way to create different graphs depending on whether the variables are observed or not.
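As a quick numerical illustration that the two formulations coincide (plain numpy, not PyMC code):

```python
import numpy as np

rng = np.random.default_rng(123)
mu, sigma, steps = 0.5, 1.0, 10
innovations = rng.normal(mu, sigma, size=steps)

# cumsum formulation: x = cumsum([x_0, eps_1, ..., eps_T])
walk_cumsum = np.concatenate([[0.0], innovations]).cumsum()

# sequential formulation: x_t ~ N(x_{t-1} + mu, sigma), same noise
walk_seq = np.zeros(steps + 1)
for t in range(1, steps + 1):
    walk_seq[t] = walk_seq[t - 1] + innovations[t - 1]

np.testing.assert_allclose(walk_cumsum, walk_seq)
```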

ricardoV94 avatar May 31 '22 07:05 ricardoV94

I see, so there is no change from this implementation to the previous one in v4, but it is different compared to v3, right?

twiecki avatar May 31 '22 07:05 twiecki

> I see, so there is no change from this implementation to the previous one in v4, but it is different compared to v3, right?

No, V3 did the same as well. The logp is written a bit differently, but it's equivalent to the manual logp we had here in V4 up to this PR: https://github.com/pymc-devs/pymc/blob/ed74406735b2faf721e7ebfa156cc6828a5ae16e/pymc3/distributions/timeseries.py#L227-L246

And there was no special transformation either.

ricardoV94 avatar May 31 '22 07:05 ricardoV94

Oh right, I read your previous comment again. It's always been $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$, so this makes sense 👍

twiecki avatar May 31 '22 07:05 twiecki

cc @nicospinu for GSOC reference

This PR will end up being a good reference for a side-by-side comparison of how time series distributions are implemented in PyMC now and how they'll be implemented once more tightly integrated with AePPL.

canyon289 avatar Jun 05 '22 19:06 canyon289

To my knowledge, three sets of tests are failing, and they are definitely worth looking into:

  • (Static?) shape checks
  • pm.logp yields a component-wise log-probability rather than the sum that the previous API seemed to return
  • Inferring step sizes from shapes. Would this be possible with AePPL?

Would these issues be important to address? To my knowledge, some of them, especially static shape inference, would take substantial work to be included in Aesara. Happy to hear thoughts about this.

Uncommenting the error pertaining to unexpected rv nodes here still yields many errors, so this would be important to look into

larryshamalama avatar Jun 05 '22 19:06 larryshamalama

> To my knowledge, three sets of tests are failing, and they are definitely worth looking into:
>
> • (Static?) shape checks

Can you expand?

> • pm.logp yields a component-wise log-probability rather than the sum that the previous API seemed to return

Yes, that behaves differently, but I don't think it poses a significant problem. We can just change the expectation of the tests.

> • Inferring step sizes from shapes. Would this be possible with AePPL?

That should work the same way as before. We do that in AR, which is also a symbolic dist. It happens before any call to rv_op and does not interfere with or depend on AePPL.
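For reference, the pattern is roughly the following (a simplified sketch of what AR does; the helper name is illustrative):

```python
def infer_steps(steps, shape):
    # Resolve `steps` from an explicit argument or from the trailing
    # dimension of `shape`; the walk has steps + 1 points (init + steps).
    if steps is None:
        if shape is None:
            raise ValueError("Must specify either steps or shape")
        steps = shape[-1] - 1
    return steps
```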

ricardoV94 avatar Jun 05 '22 20:06 ricardoV94

> To my knowledge, three sets of tests are failing, and they are definitely worth looking into:
>
> • (Static?) shape checks
>
> Can you expand?

Some tests that checked for shape inference were failing. They seem to have gone away, whether or not that is related to the implemented change_size. However, change_size is incorrectly implemented; can you have a look? Notably, test_shape_ellipsis is failing. I will tag you in a review.

> • pm.logp yields a component-wise log-probability rather than the sum that the previous API seemed to return
>
> Yes, that behaves differently, but I don't think it poses a significant problem. We can just change the expectation of the tests.

In the tests, I summed up the logp terms for now.
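Concretely, the tweak amounts to something like this (a hypothetical snippet assuming the new component-wise behavior described above):

```python
import numpy as np
import pymc as pm

grw = pm.GaussianRandomWalk.dist(mu=0.0, sigma=1.0, steps=10)
value = np.zeros(11)  # init point plus 10 steps

logp_terms = pm.logp(grw, value)      # now one term per time point
total_logp = logp_terms.sum().eval()  # the scalar the old API returned
```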

larryshamalama avatar Jun 07 '22 01:06 larryshamalama

Converted this PR back to a draft. These lines probably warrant discussion, as the implementation of GaussianRandomWalk wouldn't work. Other things to be done:

  • Moment tests
  • Test to catch NotImplementedError in moment dispatching of CumOp
  • Investigating test_gaussianrandomwalk in test_distributions.py

larryshamalama avatar Jun 07 '22 16:06 larryshamalama

Still needs a test for moments

larryshamalama avatar Jun 29 '22 12:06 larryshamalama

Actually, I'm realizing that we need a dispatch for the DimShuffle op because moment(mu[..., None]) is raising a NotImplementedError. @ricardoV94 Does this sound right? This would be useful for the following model:

```python
with pm.Model() as model:
    mu = pm.Normal("mu", 2, 3)
    sigma = pm.Gamma("sigma", 1, 1)
    grw = pm.GaussianRandomWalk(name="grw", mu=mu, sigma=sigma, init_dist=pm.StudentT.dist(5), steps=10)
```

larryshamalama avatar Jun 29 '22 16:06 larryshamalama

> Actually, I'm realizing that we need a dispatch for the DimShuffle op because moment(mu[..., None]) is raising a NotImplementedError. @ricardoV94 Does this sound right? This would be useful for the following model:
>
> ```python
> with pm.Model() as model:
>     mu = pm.Normal("mu", 2, 3)
>     sigma = pm.Gamma("sigma", 1, 1)
>     grw = pm.GaussianRandomWalk(name="grw", mu=mu, sigma=sigma, init_dist=pm.StudentT.dist(5), steps=10)
> ```

We can handle it in the same moment dispatch function; just have an if branch for that. It's the same discussion we had before about whether we specialize or create more generalized moment functions.
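A sketch of the generalized version (the `_moment` import path is an assumption about where the dispatcher lives):

```python
from aesara.tensor.elemwise import DimShuffle
from pymc.distributions.distribution import _moment, moment  # assumed path

@_moment.register(DimShuffle)
def dimshuffle_moment(op, rv, rv_input):
    # The moment of a dimshuffled variable is the dimshuffled moment
    # of its input
    return moment(rv_input).dimshuffle(op.new_order)
```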

ricardoV94 avatar Jun 29 '22 16:06 ricardoV94

> Actually, I'm realizing that we need a dispatch for the DimShuffle op because moment(mu[..., None]) is raising a NotImplementedError. @ricardoV94 Does this sound right? This would be useful for the following model:
>
> ```python
> with pm.Model() as model:
>     mu = pm.Normal("mu", 2, 3)
>     sigma = pm.Gamma("sigma", 1, 1)
>     grw = pm.GaussianRandomWalk(name="grw", mu=mu, sigma=sigma, init_dist=pm.StudentT.dist(5), steps=10)
> ```

I realized that I could just do moment(mu)[..., None] instead 😅.

I just added tests for moments. I believe that this PR is ready to be merged, but the moment tests are only passing locally so far; they will pass on CI once PR 151 in AePPL is merged.

larryshamalama avatar Jun 30 '22 11:06 larryshamalama

With #5955 soon to be merged, I believe that this PR would be good to go? It adds a hackish tweak where we don't check for random variables among the ancestors of a Join Op, due to issue 149 in AePPL, so we can perhaps think of a more sophisticated way of dealing with this later.

larryshamalama avatar Jul 07 '22 14:07 larryshamalama

> With #5955 soon to be merged, I believe that this PR would be good to go? It adds a hackish tweak where we don't check for random variables among the ancestors of a Join Op, due to issue 149 in AePPL, so we can perhaps think of a more sophisticated way of dealing with this later.

I don't think we should proceed with the hack at all; it mixes logic that should be separated. We should override AePPL's Join logprob instead, as we have stricter requirements than they do, and do the constant folding of the shapes there.
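For the constant-fold part, the idea is to evaluate the symbolic lengths of the joined inputs down to concrete integers when building the logp; a toy illustration:

```python
import aesara
import aesara.tensor as at

x = at.zeros((3,))
y = at.ones((5,))
joined = at.join(0, x, y)

# A graph with no required inputs folds the shapes down to constants
lengths = aesara.function([], [x.shape[0], y.shape[0]])()
print(lengths)  # [array(3), array(5)]
```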

ricardoV94 avatar Jul 07 '22 14:07 ricardoV94

Oh, my bad, and that sounds good. @brandonwillard recommended editing a graph rewrite or registering another one in the database for shape inference. ShapeFeature can also be added as a listener on the FunctionGraph if it is not included automatically. This would be an AePPL-centred solution. An override within PyMC would also work, I suppose.

larryshamalama avatar Jul 08 '22 14:07 larryshamalama

@larryshamalama do you feel like picking this up after #6072? I think that makes our lives easier here

ricardoV94 avatar Aug 29 '22 10:08 ricardoV94

Absolutely

larryshamalama avatar Aug 29 '22 11:08 larryshamalama

Actually, it seems like #6072 will have to include this one by necessity.

ricardoV94 avatar Aug 30 '22 17:08 ricardoV94

Done in #6072

ricardoV94 avatar Sep 05 '22 16:09 ricardoV94