Incorporating AePPL into `GaussianRandomWalk`
Closes #5762.
This PR removes `GaussianRandomWalkRV` altogether by defining an `rv_op` as a cumulative sum of distributions. The most recent version of AePPL, i.e. 0.0.31, is now able to retrieve the appropriate logp graph from a `CumOp`.
This is a WIP for now as I figure out how to use AePPL...
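For context, a minimal sketch of the construction described above, assuming scalar `mu`/`sigma` and a scalar `init_dist`; `grw_graph` and its signature are hypothetical illustrations, not the PR's actual `rv_op` code:

```python
import aesara.tensor as at
import pymc as pm

def grw_graph(mu, sigma, init_dist, steps):
    # i.i.d. Normal innovations, one per step
    innovations = pm.Normal.dist(mu=mu, sigma=sigma, size=steps)
    # prepend the initial draw and take a cumulative sum;
    # AePPL can then recover the logp graph from the resulting CumOp
    walk = at.cumsum(at.concatenate([init_dist[..., None], innovations], axis=-1), axis=-1)
    return walk

# e.g. grw_graph(0.0, 1.0, pm.Normal.dist(0.0, 100.0), 10) gives a length-11 walk
```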
Codecov Report
Merging #5814 (9a7badd) into main (906fcdc) will decrease coverage by 6.24%. The diff coverage is 93.18%.
:exclamation: Current head 9a7badd differs from pull request most recent head dbeb801. Consider uploading reports for the commit dbeb801 to get more accurate results
```diff
@@            Coverage Diff             @@
##             main    #5814      +/-   ##
==========================================
- Coverage   89.26%   83.01%    -6.25%
==========================================
  Files          72       73        +1
  Lines       12890    13261      +371
==========================================
- Hits        11506    11009      -497
- Misses       1384     2252      +868
```
| Impacted Files | Coverage Δ |
|---|---|
| pymc/distributions/timeseries.py | 77.55% <92.85%> (-1.09%) :arrow_down: |
| pymc/distributions/logprob.py | 89.55% <100.00%> (-8.18%) :arrow_down: |
| pymc/func_utils.py | 22.50% <0.00%> (-65.88%) :arrow_down: |
| pymc/distributions/bound.py | 45.54% <0.00%> (-54.46%) :arrow_down: |
| pymc/distributions/mixture.py | 59.35% <0.00%> (-36.37%) :arrow_down: |
| pymc/exceptions.py | 68.96% <0.00%> (-31.04%) :arrow_down: |
| pymc/distributions/discrete.py | 71.35% <0.00%> (-27.87%) :arrow_down: |
| pymc/tuning/scaling.py | 35.18% <0.00%> (-24.08%) :arrow_down: |
| pymc/math.py | 47.34% <0.00%> (-22.71%) :arrow_down: |
| pymc/distributions/continuous.py | 76.31% <0.00%> (-21.17%) :arrow_down: |
| ... and 22 more | |
Nice! However, this isn't the same thing; it's a different parameterization of the same thing, like centered vs. non-centered. So I don't think we should remove the previous one, as people might be using it and changing the parameterization might screw with the inference.
However, we could add it either as an option or as a separate dist.
What do you mean? This should be completely equivalent to what we had before.
You might be thinking of the latent parametrizations but that was not what GRW did before (as that wouldn't have been a vanilla distribution)
Something like that would just require implementing a default (diff) transform which would definitely be helpful for random walks.
Alternatively it would be useful to allow different implementations between unobserved and observed variables. Either eagerly at creation time or, even better, lazily when compiling the logp, but the latter requires some more work (make sure we have the right deterministics).
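A hedged sketch of what such a default "diff" transform could look like, assuming it subclasses AePPL's `RVTransform` the way other PyMC transforms do; the class name and details are illustrative only:

```python
import aesara.tensor as at
from aeppl.transforms import RVTransform

class CumsumDiffTransform(RVTransform):
    """Sample the walk in innovation (diff) space."""

    name = "cumsum_diff"

    def forward(self, value, *inputs):
        # walk -> innovations: keep the first value, difference the rest
        return at.concatenate([value[..., :1], at.diff(value, axis=-1)], axis=-1)

    def backward(self, value, *inputs):
        # innovations -> walk
        return at.cumsum(value, axis=-1)

    def log_jac_det(self, value, *inputs):
        # the cumsum/diff pair has a unit-triangular Jacobian, so the log-determinant is zero
        return at.zeros_like(value[..., 0])
```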
Isn't this changing the previous parameterization, $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$, to $x_t \sim x_{t-1} + \mathcal{N}(0, \sigma)$?
Where would the parametrization be changed?
The random generator method is exactly the same as before if you check what was going on in the old `rng_fn`. The logp is exactly the same as before (otherwise this test would fail): https://github.com/pymc-devs/pymc/blob/5703a9d24ca01da8b81363874e023e59cd18688d/pymc/tests/test_distributions.py#L2609-L2624
The expression $x_{t-1} + \mathcal{N}(0, \sigma)$ "does not have a logp"; it always has to be converted to $\mathcal{N}(x_{t-1}, \sigma)$, which AePPL does when it sees it or, in this case, when it sees $x \sim \text{cumsum}(\mathcal{N}(0, \sigma))$.
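For concreteness, the equivalence being described: with i.i.d. innovations $\epsilon_t \sim \mathcal{N}(0, \sigma)$ and $x_t = x_{t-1} + \epsilon_t$ (i.e. $x = \text{cumsum}(\epsilon)$), the change of variables $\epsilon_t = x_t - x_{t-1}$ has unit Jacobian determinant, so

$$p(x_1, \dots, x_T \mid x_0) = \prod_{t=1}^{T} \mathcal{N}(x_t \mid x_{t-1}, \sigma),$$

which is exactly the $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$ form of the logp.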
Now for NUTS sampling (when we have unobserved GRW) it would probably be better if we were sampling in the latent space $x_{\text{raw}} \sim \mathcal{N}(0, \sigma)$ and only then did the deterministic mapping to $x = \text{cumsum}(x_{\text{raw}})$, but this was not done before nor is it now.
For that we would need to either: 1) Add a new type of transform or 2) Have a way to create different graphs depending on whether the variables are observed or not.
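As a rough illustration of that latent formulation (not what this PR implements), one can already write it manually today; all variable names and parameter values here are placeholders:

```python
import aesara.tensor as at
import pymc as pm

with pm.Model():
    sigma = pm.HalfNormal("sigma", 1.0)
    # NUTS samples the innovations in the latent space ...
    x_raw = pm.Normal("x_raw", 0.0, sigma, size=10)
    # ... and the walk itself is only a deterministic mapping of them
    x = pm.Deterministic("x", at.cumsum(x_raw, axis=-1))
```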
I see, so there is no change from this implementation to the previous in v4, but it is different compared to v3, right?
No, V3 did the same as well. The logp is written a bit differently, but it's equivalent to the manual logp we had here in V4 up to this PR: https://github.com/pymc-devs/pymc/blob/ed74406735b2faf721e7ebfa156cc6828a5ae16e/pymc3/distributions/timeseries.py#L227-L246
And there was no special transformation either.
Oh right, I read your previous comment again. It's always been $x_t \sim \mathcal{N}(x_{t-1}, \sigma)$, so this makes sense 👍
cc @nicospinu for GSOC reference
This PR will end up being a good reference for a side-by-side comparison of how time series are currently implemented in PyMC and how they'll be implemented once more tightly integrated with AePPL.
To my knowledge, three sets of tests are failing and they are definitely worth looking into.

- (Static?) shape checks
- `pm.logp` yields component-wise log-probability rather than a sum, which seems to be the case in the previous API
- Inferring step sizes from shapes. Would this be possible with AePPL?

Would these issues be important to address? To my knowledge, some of these, especially static shape inference, would take substantial work to be included in Aesara. Happy to hear thoughts about this.

Uncommenting the error pertaining to unexpected rv nodes here still yields many errors, so this would be important to look into.
> To my knowledge, three sets of tests are failing and they are definitely worth looking into.
>
> - (Static?) shape checks

Can you expand?

> - `pm.logp` yields component-wise log-probability rather than a sum, which seems to be the case in the previous API

Yes, that behaves differently, but I don't think it poses a significant problem. We can just change the expectation of the tests.

> - Inferring step sizes from shapes. Would this be possible with AePPL?

That should work the same way as before. We do that in AR, which is also a symbolic dist. It happens before any call to `rv_op`. It does not interfere with or depend on AePPL.
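If I understand the shape-based inference being referred to, it is the usual PyMC pattern where `steps` can be left out and read off the trailing dimension of an explicit `shape`; treat the exact call below as an assumption about this PR's API rather than a confirmed example:

```python
import pymc as pm

with pm.Model():
    # shape (11,) implies 10 steps plus the initial value, so steps need not be passed
    grw = pm.GaussianRandomWalk("grw", mu=0.0, sigma=1.0, shape=(11,))
```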
> > - (Static?) shape checks
>
> Can you expand?

Some tests that checked for shape inference were failing. They seem to have gone away, whether or not this is related to the implemented `change_size`. However, it is incorrectly implemented. Can you have a look at this? Notably, `test_shape_ellipsis` is failing. I will tag you in a review.

> > - `pm.logp` yields component-wise log-probability rather than a sum, which seems to be the case in the previous API
>
> Yes, that behaves differently, but I don't think it poses a significant problem. We can just change the expectation of the tests.

In the tests, I summed up the `logp` terms, for now.
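For reference, a hedged sketch of what summing the terms looks like in such a test; the distribution arguments and the value's length are placeholder assumptions:

```python
import numpy as np
import pymc as pm

grw = pm.GaussianRandomWalk.dist(
    mu=0.0, sigma=1.0, init_dist=pm.Normal.dist(0.0, 100.0), steps=10
)
# under this PR, pm.logp returns one term per time point ...
logp_elemwise = pm.logp(grw, np.zeros(11))
# ... so the test compares the summed value against the scalar reference
logp_total = logp_elemwise.sum().eval()
```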
Converted this PR back to draft. These lines probably warrant discussion as the implementation of the `GaussianRandomWalk` wouldn't work. Other things to be done:

- Moment tests
- Test to catch `NotImplementedError` in moment dispatching of `CumOp`
- Investigating `test_gaussianrandomwalk` in `test_distributions.py`
Still needs a test for moments
Actually, I'm realizing that we need a dispatch for the `DimShuffle` op because `moment(mu[..., None])` is raising a `NotImplementedError`. @ricardoV94 Does this sound right? This would be useful for the following model:

```python
with pm.Model() as model:
    mu = pm.Normal("mu", 2, 3)
    sigma = pm.Gamma("sigma", 1, 1)
    grw = pm.GaussianRandomWalk(name="grw", mu=mu, sigma=sigma, init_dist=pm.StudentT.dist(5), steps=10)
```
> Actually, I'm realizing that we need a dispatch for the `DimShuffle` op because `moment(mu[..., None])` is raising a `NotImplementedError`. @ricardoV94 Does this sound right?
We can handle it in the same dispatch moment function, just have an if branch for that. It's the same discussion we had before of whether we specialize or create more generalized moment functions.
I realized that I could just do `moment(mu)[..., None]` instead 😅.
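To illustrate the difference (the import path for `moment` is an assumption on my part):

```python
import pymc as pm
from pymc.distributions.distribution import moment

mu = pm.Normal.dist(2, 3)
# moment(mu[..., None]) would dispatch on the DimShuffle op and raise NotImplementedError,
# whereas taking the moment first and adding the axis afterwards avoids that dispatch:
mu_moment = moment(mu)[..., None]
print(mu_moment.eval())  # [2.]
```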
I just added tests for moments. I believe that this PR is ready to be merged, but the moment tests are only passing locally so far and will pass once PR 151 in AePPL is merged.
With #5955 soon to be merged, I believe that this PR would be good to go? This PR adds a hackish tweak: we don't check for random variables in the ancestors of a `Join` Op, due to issue 149 in AePPL, so we can perhaps think of a more sophisticated way of dealing with this later.
I don't think we should proceed with the hack at all. It's mixing logic that should be separated. We should override AePPL's `Join` logprob instead, as we have stricter requirements than they do, and do the constant folding of the shapes there.
Oh, my bad, and sounds good. It was recommended by @brandonwillard to edit a graph rewrite or register another one in the database for shape inference. `ShapeFeature` can also be added as a listener on `FunctionGraph` calls if the former is not automatically included. This would be an AePPL-centred solution. An override within PyMC would also work, I suppose.
@larryshamalama do you feel like picking this up after #6072? I think that makes our lives easier here
Absolutely
Actually it seems like #6072 will have to include this one by necessity
Done in #6072