
`local_dimshuffle_rv_lift` does not handle dimensions dropped by `DimShuffle`

Open ricardoV94 opened this issue 2 years ago • 4 comments

```python
import aesara
import aesara.tensor as at
from aesara.graph import FunctionGraph
from aesara.tensor.random.opt import local_dimshuffle_rv_lift

x = at.random.normal(0, 1, size=(1, 2)).dimshuffle(1)
fg = FunctionGraph(outputs=[x])
print(local_dimshuffle_rv_lift.transform(fg, x.owner))  # False

x = at.random.normal([[0, 0]], 1, size=(1, 2)).dimshuffle(1)
fg = FunctionGraph(outputs=[x])
local_dimshuffle_rv_lift.transform(fg, x.owner)  # raises ValueError
```
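To make the expected behaviour concrete, here is a NumPy-only sketch of the shape bookkeeping a lift would plausibly need when `DimShuffle` drops a dimension of a univariate RV such as `normal`. The helpers `lift_dimshuffle_param` and `lift_dimshuffle_size` are hypothetical and are not part of Aesara; the real rewrite may structure this differently. The idea: left-pad each parameter to the output's number of dimensions, squeeze the axes that `new_order` drops (these must have length 1), and insert length-1 axes for `"x"` entries. Under these assumptions, `normal([[0, 0]], 1, size=(1, 2)).dimshuffle(1)` would correspond to `normal([0, 0], 1, size=(2,))`.

```python
import numpy as np


def lift_dimshuffle_param(param, out_ndim, new_order):
    """Hypothetical helper (not Aesara API): apply a DimShuffle-style
    ``new_order`` to a univariate RV parameter so the shuffle can move below the RV.

    ``out_ndim`` is the number of dimensions of the RV output (``len(size)``).
    Integers in ``new_order`` select output axes, "x" inserts a length-1 axis,
    and any axis not listed is dropped, which is only valid if it has length 1.
    """
    param = np.asarray(param)
    # Left-pad with length-1 axes, mirroring NumPy's right-aligned broadcasting.
    param = param.reshape((1,) * (out_ndim - param.ndim) + param.shape)
    kept = [d for d in new_order if d != "x"]
    dropped = [d for d in range(out_ndim) if d not in kept]
    for d in dropped:
        if param.shape[d] != 1:
            raise ValueError(f"cannot drop axis {d} of length {param.shape[d]}")
    # Reorder the kept axes, then squeeze away the (length-1) dropped axes.
    kept_shape = [param.shape[d] for d in kept]
    param = param.transpose(kept + dropped).reshape(kept_shape)
    # Insert a length-1 axis wherever new_order has "x".
    for i, d in enumerate(new_order):
        if d == "x":
            param = np.expand_dims(param, i)
    return param


def lift_dimshuffle_size(size, new_order):
    """Hypothetical helper: the ``size`` the lifted RV would need."""
    return tuple(1 if d == "x" else size[d] for d in new_order)


# Second snippet above: drop the leading (length-1) axis.
loc = lift_dimshuffle_param([[0.0, 0.0]], 2, [1])
size = lift_dimshuffle_size((1, 2), [1])
print(loc.shape, size)  # (2,) (2,)
```

Dropping a length-1 axis does not change the order in which values are drawn, so with a fixed seed the original and lifted NumPy draws coincide; dropping an axis whose length is not 1 is rejected, which matches what `DimShuffle` itself allows.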

ricardoV94 · Aug 15 '22

I also noticed that when `size` broadcasts the parameters, the rewrite is not applied:

```python
import aesara
import aesara.tensor as at
from aesara.graph import FunctionGraph
from aesara.tensor.random.opt import local_dimshuffle_rv_lift
import numpy as np

x = at.random.normal(np.zeros((4, 3, 2)), 1, size=(4, 3, 2)).T
fg = FunctionGraph(outputs=[x])
assert local_dimshuffle_rv_lift.transform(fg, x.owner)  # Fine

x = at.random.normal(np.zeros((3, 2)), 1, size=(4, 3, 2)).T
fg = FunctionGraph(outputs=[x])
assert local_dimshuffle_rv_lift.transform(fg, x.owner)  # Fails
```
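The failing case differs from the passing one only in that `size` broadcasts a parameter with fewer dimensions than the output. A NumPy-only check suggests this case should still be liftable, assuming the parameter is first left-padded with length-1 axes so that `new_order` (here `(2, 1, 0)`, which is what `.T` does for a 3-d tensor) can be applied to it. `np.arange` stands in for the issue's `np.zeros` so that a wrong axis order would actually change the values; all variable names are illustrative.

```python
import numpy as np

size = (4, 3, 2)
new_order = (2, 1, 0)  # the axis order `.T` uses for a 3-d tensor
mu = np.arange(6.0).reshape(3, 2)  # fewer dims than `size`, so `size` broadcasts it

# `mu` has only 2 axes, so `new_order` cannot be applied to it directly.
# Left-pad it to the output's number of dimensions first.
mu_padded = mu.reshape((1,) * (len(size) - mu.ndim) + mu.shape)  # shape (1, 3, 2)
mu_lifted = mu_padded.transpose(new_order)  # shape (2, 3, 1)
size_lifted = tuple(size[d] for d in new_order)  # (2, 3, 4)

# Broadcasting the lifted parameter to the lifted size gives the same array as
# broadcasting the original parameter to `size` and then transposing.
expected = np.broadcast_to(mu, size).transpose(new_order)
actual = np.broadcast_to(mu_lifted, size_lifted)
assert np.array_equal(expected, actual)
```

This only checks the parameter and `size` bookkeeping; the individual draws of the lifted RV would come out in a different order, so the equivalence is in distribution, not sample-for-sample.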

ricardoV94 · Aug 21 '22

Please print the debug graphs before and after rewriting.

brandonwillard · Aug 21 '22

> Please print the debug graphs before and after rewriting.

I am calling the rewrite manually; the unsupported cases simply return `False`, so there's nothing to print.

The only other case is the one mentioned above, which raises an error in the `DimShuffle`.

ricardoV94 · Aug 21 '22

> I am calling the rewrite manually; the unsupported cases simply return `False`, so there's nothing to print.
>
> The only other case is the one mentioned above, which raises an error in the `DimShuffle`.

I asked so that we could show the input graphs and their `Type` information; I mentioned the after-rewriting case only for if and when it becomes relevant.

brandonwillard · Aug 21 '22