
Equivalent dots not merged

Open · ricardoV94 opened this issue on Jul 13 '25 · 0 comments

Description

In the example below we end up computing 4 dots, whereas only 3 are needed:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')  # column matrix (n x 1)
# quadratic form, plus two matrix-vector products over the same A and x
f = (x.T @ A @ x), A @ x, A.T @ x

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()
# dot [id A] 2
#  ├─ dot [id B] 1
#  │  ├─ Transpose{axes=[1, 0]} [id C] 'x.T' 0
#  │  │  └─ x [id D]
#  │  └─ A [id E]
#  └─ x [id D]
# dot [id F] 3
#  ├─ A [id E]
#  └─ x [id D]
# dot [id G] 5
#  ├─ Transpose{axes=[1, 0]} [id H] 'A.T' 4
#  │  └─ A [id E]
#  └─ x [id D]

We could use associativity to rewrite (x.T @ A) @ x as x.T @ (A @ x); the inner dot is then equivalent to the second output, so the two could be merged.
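For illustration, here is what the graph looks like if that re-association is applied by hand (same setup as above, just building the inner dot once and reusing it). Since A @ x is then a single node shared by the first two outputs, only 3 dots remain:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
Ax = A @ x  # inner dot built once
f = (x.T @ Ax), Ax, A.T @ x

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()  # A @ x appears as one node shared by the first two outputs: 3 dots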

Alternatively, we could use the transpose rule to rewrite the third output A.T @ x as (x.T @ A).T, whose inner dot is exactly the innermost dot in the first output, so they could be merged (with an extra transpose, which is just a cheap view anyway); a hand-built sketch follows below.
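Again purely for illustration, the transpose-rule version applied by hand: x.T @ A is built once, and the third output is just its (view-only) transpose:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
xTA = x.T @ A  # innermost dot built once
f = (xTA @ x), A @ x, xTA.T  # third output is a cheap transpose view

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()  # x.T @ A is shared between the first and third outputs: 3 dots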

With associativity we may want to be careful, as the multiplication order can have a large impact on performance. If we know the static shapes we can optimize the contraction order the way einsum does (see also https://github.com/pymc-devs/pytensor/issues/961), but if a sub-product is already being computed anyway, reusing it can't possibly make things worse.

The example above can easily show up in the gradient of a quadratic form graph; for instance:
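A quick sketch of how that arises (same variables as above): the gradient of a quadratic form is (A + A.T) @ x mathematically, and the autodiff graph reaches that by multiplying x against both A and A.T:

import pytensor
import pytensor.tensor as pt

A = pt.dmatrix('A')
x = pt.col('x')
quad = (x.T @ A @ x).sum()  # reduce the 1x1 result to a scalar for grad
g = pytensor.grad(quad, x)  # mathematically (A + A.T) @ x

pytensor.dprint(g)  # the gradient graph contains dots against both A and A.T (up to transposes)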
