
Consider not creating `DimShuffle`s for `Elemwise` graphs

Open ricardoV94 opened this issue 2 years ago • 2 comments

import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.matrix("y")
# Adding a vector to a matrix makes `Elemwise` wrap `x` in a `DimShuffle`
z = x + y
aesara.dprint(z, print_type=True)
Elemwise{add,no_inplace} [id A] <TensorType(float64, (None, None))> ''   
 |InplaceDimShuffle{x,0} [id B] <TensorType(float64, (1, None))> ''   
 | |x [id C] <TensorType(float64, (None,))>
 |y [id D] <TensorType(float64, (None, None))>

As an example, our `RandomVariable`s manage broadcasting themselves, which leads to more succinct symbolic graphs:

z = at.random.normal(x, y)
aesara.dprint(z, print_type=True)
normal_rv{0, (0, 0), floatX, False}.1 [id A] <TensorType(float64, (None, None))> ''   
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F81DB456880>) [id B] <RandomGeneratorType>
 |TensorConstant{[]} [id C] <TensorType(int64, (0,))>
 |TensorConstant{11} [id D] <TensorType(int64, ())>
 |x [id E] <TensorType(float64, (None,))>
 |y [id F] <TensorType(float64, (None, None))>

I am not sure what the drawbacks of not using `DimShuffle`s would be for our C backend. The other backends would probably be fine, if not better off.

ricardoV94 avatar Apr 26 '22 06:04 ricardoV94

Here's an example of why these DimShuffles are bad: they prevent one from replacing terms in a graph with new terms that have different dimensions.

import aesara
import aesara.tensor as at


x = at.scalar("x")
y = at.vector("y")

z = x + y

aesara.dprint(z, print_type=True)
# Elemwise{add,no_inplace} [id A] <TensorType(float64, (None,))> ''   
#  |InplaceDimShuffle{x} [id B] <TensorType(float64, (1,))> ''   
#  | |x [id C] <TensorType(float64, ())>
#  |y [id D] <TensorType(float64, (None,))>

# Attempt to clone and recreate the graph with a change that increases the
# dimensions of the `DimShuffle`ed term
w = at.vector("w")
replacements = {x: w}
z_new = aesara.clone_replace(z, replace=replacements, strict=False)
# TypeError: The number of dimensions of the input is incorrect for this op. Expected (), got (False,).

# Now, do the same for the non-`DimShuffle`ed term
s = at.matrix("s")
replacements = {y: s}
z_new = aesara.clone_replace(z, replace=replacements, strict=False)

aesara.dprint(z_new, print_type=True)
# Elemwise{add,no_inplace} [id A] <TensorType(float64, (None, None))> ''   
#  |InplaceDimShuffle{x,0} [id B] <TensorType(float64, (1, 1))> ''   
#  | |InplaceDimShuffle{x} [id C] <TensorType(float64, (1,))> ''   
#  |   |x [id D] <TensorType(float64, ())>
#  |s [id E] <TensorType(float64, (None, None))>

This situation could arise when one wants to plug a set of sampled values into a model graph and produce posterior predictive samples for each sampled value. Currently, DimShuffle is unnecessarily preventing this, as demonstrated above.

brandonwillard avatar Apr 29 '22 20:04 brandonwillard

As @ricardoV94 mentioned in our discussions about DimShuffle lifting, removing these DimShuffles created by Elemwise.make_node would help a lot.

The basic idea is that we can move the job of `DimShuffle`ing (i.e. adding broadcastable dimensions to the inputs of an `Elemwise` so that they all have the same number of dimensions) to `Elemwise.perform`. In other words, we can make the `Op` itself do the `DimShuffle`ing.
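A minimal sketch of the shape bookkeeping such an `Op` would have to do internally, assuming static shapes given as plain tuples of ints. The names `pad_shape` and `elemwise_output_shape` are hypothetical and not part of Aesara; this only illustrates the left-padding with broadcastable dimensions that the explicit `DimShuffle` nodes perform today.

```python
def pad_shape(shape, ndim):
    """Left-pad `shape` with 1s up to `ndim` dimensions.

    This is the shape-level effect of a `DimShuffle{x, ...}` node.
    (Hypothetical helper, not an Aesara function.)
    """
    return (1,) * (ndim - len(shape)) + tuple(shape)


def elemwise_output_shape(*shapes):
    """Compute the broadcast output shape without requiring equal ndim."""
    ndim = max(len(s) for s in shapes)
    padded = [pad_shape(s, ndim) for s in shapes]
    out = []
    for dims in zip(*padded):
        # Dimensions of length 1 broadcast; all others must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)


# A vector and a matrix combine without any explicit padding by the caller.
vector_plus_matrix = elemwise_output_shape((3,), (2, 3))  # (2, 3)
scalar_plus_vector = elemwise_output_shape((), (4,))  # (4,)
```

With the padding done inside the operation, callers pass inputs of any compatible rank and the graph never needs to record how each one was padded.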

This would simplify our IR by removing the specificity of `Elemwise` inputs. The example above illustrates this by showing that our current graphs look like add(reshape(x, (1,)), y), which implies that x has a shape compatible with the reshape operation. The level of specificity imposed by such graphs prevents us from easily replacing x with a differently shaped term (e.g. one that is already a vector and neither needs nor can have the reshape operation applied to it). If our graphs instead looked like add(x, y), with the add operation handling the shape disparities itself, replacing x with any other term that's shape-compatible with y would be simple.
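The difference between the two graph forms can be sketched with a toy IR. Everything here (`Var`, `DimShuffle`, `Add`, `num_dims`, `replace`) is a hypothetical stand-in written for illustration, not Aesara's classes or its `clone_replace`; the only assumption carried over from the example above is that an explicit `DimShuffle` node rejects inputs with an unexpected number of dimensions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str
    ndim: int


@dataclass(frozen=True)
class DimShuffle:
    """Explicit node that only accepts inputs with `in_ndim` dimensions."""
    arg: object
    in_ndim: int
    out_ndim: int

    def __post_init__(self):
        if num_dims(self.arg) != self.in_ndim:
            raise TypeError("input has the wrong number of dimensions")


@dataclass(frozen=True)
class Add:
    """Toy elementwise add that tolerates inputs of different ranks."""
    left: object
    right: object


def num_dims(node):
    if isinstance(node, Var):
        return node.ndim
    if isinstance(node, DimShuffle):
        return node.out_ndim
    return max(num_dims(node.left), num_dims(node.right))


def replace(node, old, new):
    """Rebuild `node` with every occurrence of `old` swapped for `new`."""
    if node == old:
        return new
    if isinstance(node, DimShuffle):
        return DimShuffle(replace(node.arg, old, new), node.in_ndim, node.out_ndim)
    if isinstance(node, Add):
        return Add(replace(node.left, old, new), replace(node.right, old, new))
    return node


x, y, w = Var("x", 0), Var("y", 1), Var("w", 1)
explicit = Add(DimShuffle(x, 0, 1), y)  # like add(reshape(x, (1,)), y)
implicit = Add(x, y)                    # like add(x, y)

try:
    replace(explicit, x, w)
    explicit_ok = True
except TypeError:
    explicit_ok = False  # the rebuilt DimShuffle rejects the vector input

implicit_new = replace(implicit, x, w)  # rebuilds without complaint
```

In this toy model the substitution fails only because the graph recorded how x was padded; the graph without that node has nothing to invalidate.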

The same goes for unifying/matching graphs; we need to construct elaborate DimShuffle-ignoring patterns and/or logic in order to work around this issue.
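A rough sketch of the kind of workaround this refers to, using nested tuples as a stand-in for graph terms. The term encoding, the `"?"` wildcard convention, and the helpers `strip_dimshuffles` and `matches` are all assumptions made for this illustration and do not reflect Aesara's unification machinery.

```python
def strip_dimshuffles(term):
    """Drop every ("dimshuffle", arg) wrapper from a nested-tuple term."""
    if isinstance(term, tuple):
        if term[0] == "dimshuffle":
            return strip_dimshuffles(term[1])
        return (term[0],) + tuple(strip_dimshuffles(t) for t in term[1:])
    return term


def matches(pattern, term):
    """Structural match; pattern strings starting with "?" match anything."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        return True
    if isinstance(pattern, tuple) and isinstance(term, tuple):
        return len(pattern) == len(term) and all(
            matches(p, t) for p, t in zip(pattern, term)
        )
    return pattern == term


graph = ("add", ("dimshuffle", "x"), "y")  # what the graph currently looks like
pattern = ("add", "x", "?any")             # what one would like to write

direct = matches(pattern, graph)                          # False
after_strip = matches(pattern, strip_dimshuffles(graph))  # True
```

The extra normalization pass is what would become unnecessary if the graph were add(x, y) to begin with.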

Also, at least one recently merged PR should help make this change possible (e.g. https://github.com/aesara-devs/aesara/pull/928). I think that the remaining work involves the addition of scalar argument support in Elemwise._c_all (and perhaps some other places).

brandonwillard avatar Jul 11 '22 21:07 brandonwillard