aesara
Implement lifting for `Subtensor`s and `DimShuffle`s applied to multivariate `RandomVariable`s
Our current rewrites, `aesara.tensor.random.opt.local_subtensor_rv_lift` and `local_dimshuffle_rv_lift`, do not support multivariate `RandomVariable`s.
This is in part due to a fundamental restriction involving the way inputs are mapped to support dimensions (e.g. `DirichletRV`), but not exclusively.
For at least the `DimShuffle` case, it's possible that an additional parameter indicating the support dimension(s) could be used to work around this restriction.
From https://github.com/aesara-devs/aeppl/issues/150#issuecomment-1179253553:

> The reason we can't lift the `DimShuffle` through `DirichletRV` is that `DirichletRV` fixes the last dimension of its inputs as the core/support dimension, so that `z.T` for `z = dirichlet(a)` results in a sample that has no possible corresponding `f(a)`. However, if we parameterize `DirichletRV` on its core/support dimension, we get `z == dirichlet(a, c=-1)` and `z.T == dirichlet(a.T, c=0)`.
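The identity above can be checked numerically with a short NumPy sketch. This is not Aesara code. A hypothetical `axis` argument stands in for `c`, and the Dirichlet sample is built by normalizing independent Gamma(a_i, 1) draws along that axis (the standard Gamma construction). The Gamma draws are elementwise, so transposing them before or after normalizing gives the same result. That is the statement `z.T == dirichlet(a.T, c=0)`.

```python
import numpy as np


def normalize_gamma(g, axis=-1):
    """Turn elementwise Gamma(a_i, 1) draws into a Dirichlet sample along `axis`."""
    return g / g.sum(axis=axis, keepdims=True)


def dirichlet(rng, a, axis=-1):
    """Hypothetical Dirichlet sampler whose support dimension is `axis` (the `c` above)."""
    g = rng.gamma(shape=np.asarray(a, dtype=float))  # one Gamma(a_i, 1) draw per entry
    return normalize_gamma(g, axis=axis)


rng = np.random.default_rng(0)
a = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 2 batch entries, support size 3
g = rng.gamma(shape=a)

z = normalize_gamma(g, axis=-1)     # z = dirichlet(a, c=-1)
z_t = normalize_gamma(g.T, axis=0)  # dirichlet(a.T, c=0), built from the Gamma(a.T) draws g.T

print(np.allclose(z.T, z_t))  # True: z.T is a dirichlet(a.T, c=0) sample
print(dirichlet(rng, a.T, axis=0).sum(axis=0))  # each column sums to 1 (up to rounding)
```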
If I understand your proposal, at least for the `DimShuffle` case, it doesn't seem that the ability to lift the operator justifies adding an extra parameter to all multivariate distributions.
It should, however, be safe to lift it through non-support dimensions, although I am a bit fuzzy about how you would do that for parameters with different numbers of core dimensions.
It would be nice to lift a subtensor operation, but that can only be done if it does not remove partial entries from the support dimensions.
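The subtensor case can be illustrated with a NumPy sketch that reuses the Gamma construction. This is an assumption of the sketch, not a description of how Aesara's `DirichletRV` draws samples. The sketch contrasts three operations:

- Indexing a batch dimension commutes with sampling.
- Reordering the whole support also commutes with sampling.
- Keeping only part of the support does not commute, because the result no longer sums to one.

```python
import numpy as np


def normalize_gamma(g, axis=-1):
    """Turn elementwise Gamma(a_i, 1) draws into Dirichlet samples along `axis`."""
    return g / g.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
a = np.arange(1.0, 13.0).reshape(4, 3)  # 4 batch entries, support size 3
g = rng.gamma(shape=a)
z = normalize_gamma(g)  # z ~ dirichlet(a)

# Batch indexing can be lifted: z[1:3] is a sample of dirichlet(a[1:3])
batch_lifted = normalize_gamma(g[1:3])
print(np.allclose(z[1:3], batch_lifted))  # True

# Permuting the full support can be lifted: z[:, perm] is dirichlet(a[:, perm])
perm = [2, 0, 1]
print(np.allclose(z[:, perm], normalize_gamma(g[:, perm])))  # True

# Dropping part of the support cannot: the rows no longer sum to one,
# so z[:, :2] is not a sample of dirichlet(a[:, :2])
partial = z[:, :2]
print(np.allclose(partial.sum(axis=-1), 1.0))  # False
```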
> If I understand your proposal, at least for the Dimshuffle case, it doesn't seem that the ability to lift the operator justifies adding an extra parameter to all multivariate distributions.
The costs of adding such a parameter, especially at the `Op` level, are extremely low, and the benefits are rather high (e.g. see https://github.com/aesara-devs/aeppl/issues/150#issuecomment-1179412394 for some explanations of the latter).
> It should however be safe to lift it through non-support dimensions, although I am a bit fuzzy about how you would do that for parameters with different number of core dimensions.
> It would be nice to lift a subtensor operation but that can definitely only be done if it does not remove partial entries from the support dimensions.
See the docstrings and comments in those rewrites; both should provide high- and low-level descriptions of what they're doing and why.
> The costs of adding such a parameter, especially at the Op-level, are extremely low,
Just to be clear, here I meant developer costs, not performance costs.
I guess my confusion is why we care about lifting `DimShuffle`s in particular, and not other operators like `exp` (you could do something clever with loc/scale distributions)?
Obviously it's easier to manipulate shapes. But being easier doesn't make it more relevant.
My guess is that the "issue" comes from `DimShuffle`s being introduced very frequently by `Elemwise` operators? But those cases are relatively simple to eliminate by just appending `1`s to the left of the RV's `size`. `DimShuffle`s introduced by `Elemwise` operators will never expand on the support dimension of the RV.
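The difference between a `DimShuffle` that only touches batch dimensions and one that touches the support can be expressed as a check on the `DimShuffle` pattern. The pure-Python sketch below is hypothetical: the helper name and its `ndim_supp` argument are illustrative, and only the `new_order` convention is taken from `DimShuffle` (input axis indices, plus `"x"` for new broadcastable axes). It treats a pattern as safe for a multivariate RV when the support axes remain the trailing axes in their original order. A left-padding `("x", 0, 1)`, as produced by `Elemwise` broadcasting, passes this check.

```python
def only_touches_batch_dims(new_order, ndim, ndim_supp):
    """Return True if a DimShuffle `new_order` leaves an RV's support axes untouched.

    Hypothetical helper: the last `ndim_supp` axes of the `ndim`-dimensional
    input are the support axes, and `new_order` follows DimShuffle's convention
    of input axis indices plus "x" for new broadcastable axes.
    """
    if ndim_supp == 0:
        return True  # univariate: every axis is a batch axis
    support_axes = tuple(range(ndim - ndim_supp, ndim))
    # The support axes must stay last and keep their original order
    return tuple(new_order[-ndim_supp:]) == support_axes


# A (5, 3) Dirichlet draw: one batch axis, one support axis
print(only_touches_batch_dims(("x", 0, 1), ndim=2, ndim_supp=1))  # True: Elemwise-style left expand
print(only_touches_batch_dims((1, 0), ndim=2, ndim_supp=1))       # False: z.T moves the support
print(only_touches_batch_dims((0, 1, "x"), ndim=2, ndim_supp=1))  # False: new axis after the support
# A (4, 5, 3) Dirichlet draw: swapping the two batch axes is fine
print(only_touches_batch_dims((1, 0, 2), ndim=3, ndim_supp=1))    # True
```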
Alternatively, we could just stop adding `DimShuffle`s for `Elemwise` operations, as is discussed in another issue.
> I guess my confusion is why do we care about lifting DimShuffles in particular, and not other operators like `exp` (you could do something clever with loc/scale distributions)?
This is not an "either ... or ..." situation, and lifting exponentials through `RandomVariable`s is not even a thing, at least not without producing an entirely different `RandomVariable` in most cases. If you're referring to completely different rewrites, then that's off topic for this issue.
> Obviously it's easier to manipulate shapes. But being easier doesn't make it more relevant.
If you think this specific type of rewrite, or `RandomVariable` canonicalizations in general, aren't relevant, then try removing them from AePPL and see if everything still works. More importantly, you'll need to do that and guarantee that `DimShuffle`s won't be a problem when they're applied to `RandomVariable`s for any and every other type of model and inputs as well, and not just in AePPL.
> My guess is that the "issue" comes from Dimshuffles being introduced very frequently by Elemwise operators? But those cases are relatively simple to eliminate by just appending `1`s to the (left of the size) of the RV. Dimshuffles introduced by Elemwise operators will never expand on the support dimension of the RV.
The "issue" is very real and relevant, because `DimShuffle`s can and do appear for multiple reasons (e.g. automatically introduced by `Op`s, rewrites, users), and they unequivocally confound the process of identifying subgraphs that contain `RandomVariable`s. That reason alone is enough justification for doing this work.
No, appending `1`s will not fix the same issues solved by lifting `DimShuffle`s. All you're proposing is that we somehow standardize the number of dimensions of `RandomVariable`s, but when and where does that make sense, and what's the correct number of dimensions to use in any given instance? There's no way to make a canonicalization (or broadly applicable rewrite) out of such an approach.
> Alternatively we could just stop adding Dimshuffles for Elemwise operations as is discussed in another issue.
That would fix nothing; it would only somewhat reduce the occurrence of `DimShuffle`s applied to `RandomVariable`s. There is no reasonable way to eliminate them entirely, because it would imply that Aesara wouldn't even support `DimShuffle`s of `RandomVariable`s.
> If you think this specific type of rewrite, or RandomVariable canonicalizations in general, aren't relevant, then try removing them from AePPL and see if everything still works
I think it would still work; you only need to convert them to a `MeasurableDimShuffle`. Our rewrites aren't looking for pure RVs anymore, and being a `MeasurableCumsum` or `MeasurableDimShuffle` should pose exactly the same limitations. I fail to see what is special about `DimShuffle`s that we want to get rid of them (other than the fact that we can).
Outside of AePPL, it seems to me that your proposal above transfers the responsibility of dimshuffling to `RandomVariable.perform`, by adding an auxiliary parameter that determines the axis of support.
That's the part I don't see the point of. Again, that's because I don't see what problem `DimShuffle`s introduce.
> No, appending 1s will not fix the same issues solved by lifting DimShuffles. All you're proposing is that we somehow standardize the number of dimensions of RandomVariables, but when and where does that make sense, and what's the correct number of dimensions to use in any given instance? There's no way to make a canonicalization (or broadly applicable rewrite) out of such an approach.
It seems to me like we can always do it. Whether we want to is a different question; it's the same one I am raising for the more general lift.
Conceptually, it is similar to what you are proposing, except it relies on already supported behavior of RVs and not on a new feature.
`at.random.dirichlet([1, 2, 3])[None]` -> `at.random.dirichlet([1, 2, 3], size=1)`.
In general, we can always easily canonicalize implicit batch dimensions into `size`, but we don't have to.
`at.random.dirichlet(np.ones((5, 3)))[None]` -> `at.random.dirichlet(np.ones((5, 3)), size=5)[None]` -> `at.random.dirichlet(np.ones((5, 3)), size=(1, 5))`.
We can skip the intermediate step and go directly to the last one in a single rewrite.
If you have multiple parameters, it's equally easy, but you need to broadcast the (batch) shapes of the parameters first. We already have that machinery in Aesara as well.
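The `size` computation described here can be sketched with NumPy's `np.broadcast_shapes` (available in NumPy 1.20 and later). The helper `size_for_left_expand` is hypothetical. Its `ndims_params` argument gives the number of trailing core dimensions of each parameter, in the spirit of a `RandomVariable`'s `ndims_params`. Its `n_new` argument is the number of leading `"x"` axes being folded into `size`. The final lines sample through the Gamma construction, purely to confirm that the resulting shape matches `dirichlet(a)[None]`.

```python
import numpy as np


def size_for_left_expand(param_shapes, ndims_params, n_new):
    """Hypothetical helper: the `size` that replaces `rv(*params)` with `n_new` leading new axes.

    Each parameter's batch shape is its shape minus its core dimensions;
    the batch shapes are broadcast together and `n_new` ones are prepended.
    """
    batch_shapes = [
        tuple(shape[: len(shape) - nd]) for shape, nd in zip(param_shapes, ndims_params)
    ]
    batch = np.broadcast_shapes(*batch_shapes)
    return (1,) * n_new + tuple(batch)


print(size_for_left_expand([(3,)], [1], n_new=1))           # (1,): dirichlet([1, 2, 3])[None]
print(size_for_left_expand([(5, 3)], [1], n_new=1))         # (1, 5): dirichlet(ones((5, 3)))[None]
print(size_for_left_expand([(5,), (3,)], [0, 1], n_new=1))  # (1, 5): e.g. scalar-core n, vector-core p

# Sampling with that size gives the same shape as dirichlet(a)[None]
rng = np.random.default_rng(0)
a = np.ones((5, 3))
size = size_for_left_expand([a.shape], [1], n_new=1)
g = rng.gamma(shape=np.broadcast_to(a, size + a.shape[-1:]))
z = g / g.sum(axis=-1, keepdims=True)
print(z.shape)  # (1, 5, 3)
```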
Anyway, this is a tangential proposal to expand the scope of the existing `DimShuffle` lift rewrite to special cases that don't interfere with core dimensions and are applicable to both univariate and multivariate distributions alike. It's not an either/or question.
> because it would imply that Aesara wouldn't even support DimShuffles of RandomVariables.
I don't follow what you mean here.
> I think it would still work, you only need to convert them to a "MeasurableDimShuffle". Our rewrites aren't looking for pure RVs anymore, and being a MeasurableCumsum or MeasurableDimShuffle should pose exactly the same limitations. I fail to see what is special about Dimshuffles that we want to get rid of them (other than the fact we can).
> Outside of Aeppl, it seems to me that your proposal above transfers the responsibility of Dimshuffling to RandomVariables perform method, by adding an auxiliary parameter that determines the axis of support.
This is not the AePPL repo, and a solution in AePPL is not a solution here or anywhere else that `RandomVariable`s are used.
Also, I'm well aware of how AePPL works now, because I made it work that way; that's also why I'm able to say that the `MeasurableDimShuffle` approach isn't ideal. It's a very AePPL-specific hack that does nothing to address a more general lack of functionality.
Before going any further with this topic, please tell me how your alternative(s) address the basic rewriting challenges introduced by `DimShuffle`s. If they don't address those challenges, then they're not comparable.
If you don't know what those rewriting challenges are, then start your questions there, because I've mentioned that they're an important reason for doing this work, but you keep avoiding this point.
> It seems to me like we can always do it. Whether we want or not is a different question. The same I am raising for the more general lift. ...
After this point, the rewriting you describe is essentially what the existing `DimShuffle` lifting does (e.g. moving `DimShuffle`s into the `size` parameter and the like), so I get the feeling that you're not actually addressing the points I raised. For instance, when/where are those `[None]`s appearing, and how are they getting there?
If the answer(s) are specific to the `Elemwise`s you mentioned, then you're effectively trying to trivialize this topic by assuming that all/most `DimShuffle`s of `RandomVariable`s come from that one narrow situation and can, thus, be handled in a simple context-dependent fashion. One that apparently still depends on these lifting rewrites that you're arguing against?
Otherwise, let's keep these discussions on the topic of this issue instead: i.e. adding multivariate support to `RandomVariable` lifting rewrites.
> Outside of Aeppl, it seems to me that your proposal above transfers the responsibility of Dimshuffling to RandomVariables perform method, by adding an auxiliary parameter that determines the axis of support.
You touch upon a very relevant point here, and I'll follow up on it later, but the basic idea is that we might be able to absorb that `DimShuffle`ing within our existing `RandomVariable.perform` steps, especially the multivariate broadcasting ones.
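A rough NumPy illustration of absorbing the `DimShuffle` into a perform-style step might look like this. The function `dirichlet_perform` and its `support_axis` argument are hypothetical and do not exist in Aesara. The sketch does three things:

1. Moves the parameter's support axis to the end.
2. Runs the usual "support last" draw via Gamma normalization.
3. Moves the support axis of the result back with `np.moveaxis`.

```python
import numpy as np


def dirichlet_perform(rng, alpha, support_axis=-1):
    """Hypothetical perform-style Dirichlet draw with a configurable support axis.

    The parameter's support axis is moved to the end, the usual
    "support last" sampling runs, and the result's support axis is moved back.
    """
    alpha = np.moveaxis(np.asarray(alpha, dtype=float), support_axis, -1)
    g = rng.gamma(shape=alpha)            # one Gamma(alpha_i, 1) draw per entry
    z = g / g.sum(axis=-1, keepdims=True)  # Dirichlet along the (now last) support axis
    return np.moveaxis(z, -1, support_axis)


rng = np.random.default_rng(0)
a = np.ones((2, 3))  # 2 batch entries, support size 3

z0 = dirichlet_perform(rng, a.T, support_axis=0)  # plays the role of dirichlet(a.T, c=0)
print(z0.shape)          # (3, 2)
print(z0.sum(axis=0))    # each column sums to 1 (up to rounding)
```

In this sketch, the transpose never appears as a separate operation in the graph. It is handled by the two `np.moveaxis` calls around the existing batched draw.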