aesara
`local_dimshuffle_rv_lift` does not handle dimensions dropped by `DimShuffle`
import aesara
import aesara.tensor as at
from aesara.graph import FunctionGraph
from aesara.tensor.random.opt import local_dimshuffle_rv_lift
x = at.random.normal(0, 1, size=(1, 2)).dimshuffle(1)
fg = FunctionGraph(outputs=[x])
print(local_dimshuffle_rv_lift.transform(fg, x.owner)) # False
x = at.random.normal([[0, 0]], 1, size=(1, 2)).dimshuffle(1)
fg = FunctionGraph(outputs=[x])
local_dimshuffle_rv_lift.transform(fg, x.owner) # raises ValueError
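The dropped-dimension case can be mimicked in plain NumPy. This is a minimal sketch, not Aesara's rewrite. It assumes that lifting `.dimshuffle(1)` through the sampler means dropping the length-1 axis from `size` and also from every parameter that has that axis. The helper `sample` is hypothetical; it sets `scale=0` so every draw equals the broadcast `loc`, which lets the two evaluation orders be compared exactly.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample(loc, size):
    # Hypothetical helper: scale=0 makes every draw equal to loc broadcast to `size`,
    # so the output is deterministic and the two orders can be compared exactly.
    return rng.normal(loc, 0.0, size=size)


loc = np.array([[0.0, 1.0]])  # shape (1, 2), like [[0, 0]] but with distinct values
size = (1, 2)

# Original order: sample first, then drop the length-1 axis 0 (what .dimshuffle(1) does).
before = np.squeeze(sample(loc, size), axis=0)

# Assumed lifted order: drop axis 0 from size AND from the parameter, then sample.
after = sample(np.squeeze(loc, axis=0), size[1:])

print(before.shape, after.shape)
```

With a scalar `loc`, as in the first example, there is no parameter axis to drop and only `size` changes.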
I also noticed that when `size` broadcasts the parameters, the rewrite is not applied:
import aesara
import aesara.tensor as at
from aesara.graph import FunctionGraph
from aesara.tensor.random.opt import local_dimshuffle_rv_lift
import numpy as np
x = at.random.normal(np.zeros((4, 3, 2)), 1, size=(4, 3, 2)).T
fg = FunctionGraph(outputs=[x])
assert local_dimshuffle_rv_lift.transform(fg, x.owner) # Fine
x = at.random.normal(np.zeros((3, 2)), 1, size=(4, 3, 2)).T
fg = FunctionGraph(outputs=[x])
assert local_dimshuffle_rv_lift.transform(fg, x.owner) # Fails
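Here is a NumPy sketch of one way a lift could handle a parameter that has fewer dimensions than `size`: left-pad the parameter to the full number of dimensions, then transpose it together with `size`. This is an illustration under that assumption, not Aesara's implementation. The helper `sample` is hypothetical and uses `scale=0` so each draw equals the broadcast `loc`, making the comparison exact.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample(loc, size):
    # Hypothetical helper: scale=0 makes every draw equal to loc broadcast to `size`.
    return rng.normal(loc, 0.0, size=size)


loc = np.arange(6.0).reshape(3, 2)  # fewer dims than size, so size broadcasts it
size = (4, 3, 2)

# Original order: sample, then .T (reverses all three axes).
before = sample(loc, size).T

# Assumed lifted order: left-pad loc to the full ndim, transpose it and size, then sample.
loc_full = loc[np.newaxis, ...]          # shape (1, 3, 2)
after = sample(loc_full.T, size[::-1])   # loc shape (2, 3, 1), size (2, 3, 4)

print(before.shape, after.shape)
```

Transposing the un-padded `(3, 2)` parameter would instead give shape `(2, 3)`, which no longer lines up with the trailing axes of the reversed size `(2, 3, 4)`.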
Please print the debug graphs before and after rewriting.
I am calling the rewrite manually; the unsupported cases simply return False, so there is nothing to print after rewriting.
The only other case is the one mentioned above, which raises an error in the `DimShuffle`.
I asked so that we could see the input graphs and their `Type` information, and I mentioned the after-rewriting case only for when (or if) it is relevant.