Simplify dots with 1
Description
We have a local_0_dot_x rewrite that removes useless dots with zeroed inputs. We don't seem to have anything for dots with ones, as reported in https://github.com/pymc-devs/pytensor/discussions/637#discussioncomment-8405862
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.col('x')
f = x @ [[1.]]

with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))
pytensor.dprint(fn)
dot [id A] 0
├─ x [id B]
└─ [[1.]] [id C]
I excluded BlasOpt just to get a simpler graph; with the Blas rewrites enabled the dot still isn't rewritten away either, it's just replaced by the more complex Blas Op.
https://github.com/pymc-devs/pytensor/blob/d3dd34e7ea78eb1f125dd771d06436d49bf2ce5d/pytensor/tensor/rewriting/math.py#L155-L190
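For contrast, the zero case is already handled by local_0_dot_x. A minimal sketch (the exact dprint output may vary by version, but the dot should be gone, replaced by an alloc of zeros):

import pytensor
import pytensor.tensor as pt

x = pt.col('x')
f = x @ [[0.]]
fn = pytensor.function([x], f)
pytensor.dprint(fn)  # expect an Alloc of zeros rather than a dot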
Looks like an interesting issue. We'd just have to replace the 0 with x in local_0_dot_x, right?
Here's what I have in mind:
# imports (these names are already available in pytensor/tensor/rewriting/math.py)
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.raise_op import assert_op
from pytensor.tensor.basic import alloc, constant, get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, eq
from pytensor.tensor.rewriting.basic import register_canonicalize, register_stabilize


@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_1_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]

    # If either input is the scalar constant 1, remember the other input.
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 1:
            replace = True
            var = y
    except NotScalarConstantError:
        pass
    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 1:
            replace = True
            var = x
    except NotScalarConstantError:
        pass

    if replace:
        constant_value = constant(
            get_underlying_scalar_constant_value(var, only_process_constants=True),
            dtype=node.outputs[0].type.dtype,
        )
        # Assert the inner dimensions match, then allocate the output shape.
        if x.ndim == 2 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_value, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [constant_value]
However, I think using the constant value might be wrong here. Would I have to replace the output with the entire var instead? If so, is this the correct way of moving forward?

var = assert_op(var, eq(...))
alloc(var, shape)
No, the rule is slightly different for ones: it amounts to summing the left matrix. You also have to reason about broadcasting. I suggest playing with numpy to get a feel for what it should do.
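For instance, a quick numpy sketch of that rule (the shapes here are arbitrary, purely for illustration):

import numpy as np

# Multiplying by a matrix of ones sums the left operand over its last
# axis and broadcasts the result across the columns of the ones matrix.
x = np.arange(6.0).reshape(2, 3)
res = x @ np.ones((3, 4))
expected = np.broadcast_to(x.sum(axis=-1, keepdims=True), (2, 4))
np.testing.assert_allclose(res, expected)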
Okay. Just so I get it correctly: for a given graph, say
Sub [id A]
├─ dot [id B]
│ ├─ dot [id C]
│ │ ├─ Transpose{axes=[1, 0]} [id D] 'A.T'
│ │ │ └─ A [id E]
│ │ └─ Neg [id F]
│ │ └─ x [id G]
│ └─ [[1.]] [id H]
└─ dot [id I]
├─ A [id E]
└─ dot [id J]
├─ x [id G]
└─ [[1.]] [id H]
we want the output of the rewrite to be:
Sub [id A]
├─ dot [id B]
│ ├─ Transpose{axes=[1, 0]} [id C] 'A.T'
│ │ └─ A [id D]
│ └─ Neg [id E]
│ └─ x [id F]
└─ dot [id G]
├─ A [id D]
└─ x [id F]
Is this correct? And if so, how do summing the left matrix and broadcasting come into the picture here?
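(A quick numpy check, with hypothetical shapes, suggests where the summing comes in: dropping the dot with [[1.]] is only an identity because x is a column; with more than one column the dot is a row-sum.)

import numpy as np

# Column x, as in the graphs above: dot(x, [[1.]]) is just x.
x = np.arange(3.0).reshape(3, 1)
np.testing.assert_allclose(x @ [[1.0]], x)

# Wider x: the dot sums the rows, so the rewrite would have to emit a
# Sum (with the right broadcasting), not simply drop the dot.
y = np.arange(6.0).reshape(3, 2)
np.testing.assert_allclose(y @ [[1.0], [1.0]], y.sum(axis=-1, keepdims=True))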