pytensor
Add rewrite to fuse nested `BlockDiag` Ops
Description
This comes up in statespace models. `BlockDiagonal` is variadic, so a nested graph like this can be canonicalized into a single Op:
import pytensor.tensor as pt
import pytensor
a, b, c = pt.matrices('a', 'b', 'c')
x = pt.linalg.block_diag(pt.linalg.block_diag(a, b), c)
fn = pytensor.function([a, b, c], x)
fn.dprint()
Current output:
BlockDiagonal{n_inputs=2} [id A] 1
├─ BlockDiagonal{n_inputs=2} [id B] 0
│ ├─ a [id C]
│ └─ b [id D]
└─ c [id E]
Desired output:
BlockDiagonal{n_inputs=3} [id A] 0
├─ a [id B]
├─ b [id C]
└─ c [id D]
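The fused form should be safe because block-diagonal construction is associative. Here is a quick sanity check using a pure-Python stand-in for the Op (the helper `block_diag` below is illustrative only, not the PyTensor implementation, and assumes non-empty 2D lists):

```python
def block_diag(*mats):
    """Dense block-diagonal matrix built from non-empty 2D lists."""
    n_rows = sum(len(m) for m in mats)
    n_cols = sum(len(m[0]) for m in mats)
    out = [[0] * n_cols for _ in range(n_rows)]
    r0 = c0 = 0  # top-left corner of the next block
    for m in mats:
        for i, row in enumerate(m):
            for j, val in enumerate(row):
                out[r0 + i][c0 + j] = val
        r0 += len(m)
        c0 += len(m[0])
    return out


a = [[1, 2], [3, 4]]  # 2x2
b = [[5]]             # 1x1
c = [[6, 7]]          # 1x2

nested = block_diag(block_diag(a, b), c)
flat = block_diag(a, b, c)
assert nested == flat
```

The same holds for right-nested calls such as `block_diag(a, block_diag(b, c))`, so the rewrite would not need to care which input position the inner Op appears in.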
The only wrinkle I can see is that there might be ExpandDims Ops sandwiched between the "levels" of BlockDiagonal. The rewrite should either push these inside or pull them out.
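Setting the ExpandDims question aside, the core of the rewrite is just inlining the inputs of any inner BlockDiagonal into its parent. A minimal sketch of that logic on a toy graph (the `Var` and `BlockDiagNode` classes and `fuse_block_diag` are hypothetical stand-ins, not PyTensor classes; a real version would presumably be written as a node rewriter on the actual Op):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Toy leaf variable (stand-in for a matrix input)."""
    name: str


@dataclass(frozen=True)
class BlockDiagNode:
    """Toy variadic block-diagonal node (stand-in for the real Op)."""
    inputs: tuple


def fuse_block_diag(node):
    """Recursively inline nested BlockDiagNode inputs into their parent."""
    if not isinstance(node, BlockDiagNode):
        return node
    flat = []
    for inp in node.inputs:
        inp = fuse_block_diag(inp)  # fuse deeper levels first
        if isinstance(inp, BlockDiagNode):
            flat.extend(inp.inputs)  # splice the inner inputs in place
        else:
            flat.append(inp)
    return BlockDiagNode(tuple(flat))


a, b, c = Var("a"), Var("b"), Var("c")
nested = BlockDiagNode((BlockDiagNode((a, b)), c))
fused = fuse_block_diag(nested)
assert fused == BlockDiagNode((a, b, c))
```

Splicing in place preserves input order, which matters because the order of inputs determines the order of blocks along the diagonal.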