
Add rewrite to fuse nested `BlockDiag` Ops

Open · jessegrabowski opened this issue 3 months ago • 2 comments

Description

This comes up in statespace models. `BlockDiagonal` is variadic (it accepts any number of matrix inputs), so nested graphs like this one can be canonicalized into a single Op:

```python
import pytensor
import pytensor.tensor as pt

a, b, c = pt.matrices("a", "b", "c")
# One block_diag call nested inside another
x = pt.linalg.block_diag(pt.linalg.block_diag(a, b), c)
fn = pytensor.function([a, b, c], x)
# Print the compiled (rewritten) graph
fn.dprint()
```

Current output:

```
BlockDiagonal{n_inputs=2} [id A] 1
 ├─ BlockDiagonal{n_inputs=2} [id B] 0
 │  ├─ a [id C]
 │  └─ b [id D]
 └─ c [id E]
```
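Conceptually, the rewrite splices the inputs of any inner `BlockDiagonal` into the argument list of the outer one, keeping their order, and rebuilds the Op with the new `n_inputs`. The sketch below models this on a toy graph rather than on PyTensor's real graph objects; `BlockDiagNode` and `fuse_nested` are hypothetical names used only for illustration, not PyTensor API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockDiagNode:
    """Toy stand-in for a BlockDiagonal{n_inputs=...} apply node (hypothetical)."""

    inputs: tuple

    @property
    def n_inputs(self):
        return len(self.inputs)


def fuse_nested(node):
    """Return an equivalent node whose inputs contain no nested BlockDiagNode."""
    flat = []
    for inp in node.inputs:
        if isinstance(inp, BlockDiagNode):
            # Splice the inner node's (already flattened) inputs in place, keeping order
            flat.extend(fuse_nested(inp).inputs)
        else:
            flat.append(inp)
    return BlockDiagNode(tuple(flat))


current = BlockDiagNode((BlockDiagNode(("a", "b")), "c"))
fused = fuse_nested(current)
print(fused.n_inputs, fused.inputs)  # 3 ('a', 'b', 'c')
```

Because `fuse_nested` recurses, one call also flattens deeper nesting such as `block_diag(block_diag(block_diag(a, b), c), d)`. Input order must be preserved, since it determines where each block lands on the diagonal.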

Desired output:

```
BlockDiagonal{n_inputs=3} [id A] 0
 ├─ a [id B]
 ├─ b [id C]
 └─ c [id D]
```
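The fused graph computes the same value because block-diagonal construction is associative: nesting only groups the blocks, it does not change where each block lands on the diagonal. A quick NumPy check of that identity, using a small hand-written `block_diag` helper (a stand-in modeled on `scipy.linalg.block_diag`, not PyTensor code):

```python
import numpy as np


def block_diag(*mats):
    """Place 2-D arrays along the diagonal of a zero matrix.

    Stand-in modeled on scipy.linalg.block_diag; not PyTensor code.
    """
    rows = sum(m.shape[0] for m in mats)
    cols = sum(m.shape[1] for m in mats)
    out = np.zeros((rows, cols), dtype=np.result_type(*mats))
    row = col = 0
    for m in mats:
        h, w = m.shape
        out[row:row + h, col:col + w] = m
        row += h
        col += w
    return out


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3))
b = rng.normal(size=(1, 1))
c = rng.normal(size=(3, 2))

nested = block_diag(block_diag(a, b), c)  # mirrors the current graph
flat = block_diag(a, b, c)                # mirrors the desired graph
print(flat.shape, np.array_equal(nested, flat))  # (6, 6) True
```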

The only tricky wrinkle I can see is that there might be `ExpandDims` Ops sandwiched between the "levels" of `BlockDiagonal`. The rewrite should either push these inside (onto the inner inputs) or pull them out (past the fused Op).
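A rough NumPy model of the two options, assuming the batched Op broadcasts leading (batch) dimensions the way a `Blockwise`-wrapped Op does. `batched_block_diag` is a hypothetical helper, and `[None]` stands in for an `ExpandDims` on axis 0:

```python
import numpy as np


def batched_block_diag(*mats):
    """Block-diagonal over the last two axes; leading (batch) axes broadcast.

    Hypothetical NumPy model of a batched BlockDiagonal, not PyTensor code.
    """
    batch = np.broadcast_shapes(*(m.shape[:-2] for m in mats))
    rows = sum(m.shape[-2] for m in mats)
    cols = sum(m.shape[-1] for m in mats)
    out = np.zeros(batch + (rows, cols), dtype=np.result_type(*mats))
    row = col = 0
    for m in mats:
        h, w = m.shape[-2:]
        # Assignment broadcasts m's batch axes against the output's batch axes
        out[..., row:row + h, col:col + w] = m
        row += h
        col += w
    return out


rng = np.random.default_rng(1)
a = rng.normal(size=(2, 2))
b = rng.normal(size=(3, 3))
c = rng.normal(size=(4, 1, 1))  # an input with a batch axis of length 4

# ExpandDims sandwiched between the levels: the inner output gains a leading axis
sandwiched = batched_block_diag(batched_block_diag(a, b)[None], c)

# Option 1: push the ExpandDims inside, onto each inner input, then fuse
pushed = batched_block_diag(a[None], b[None], c)

# Option 2: pull it out past the fused Op (every outer input carries the new axis)
c2 = rng.normal(size=(1, 4))
sandwiched2 = batched_block_diag(batched_block_diag(a, b)[None], c2[None])
pulled = batched_block_diag(a, b, c2)[None]

print(sandwiched.shape, np.array_equal(sandwiched, pushed))  # (4, 6, 6) True
print(pulled.shape, np.array_equal(sandwiched2, pulled))     # (1, 6, 9) True
```

In this model, pushing the leading `ExpandDims` onto each inner input reproduces the sandwiched result, while pulling it out works here only because both outer inputs carry the same new axis.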

jessegrabowski · Aug 26 '25 06:08