
WIP: Add rewrite to fuse nested BlockDiag Ops

Open eby0303 opened this issue 2 months ago • 3 comments

Description

This is a draft PR for issue #1593. I’m setting up the local environment and exploring how to implement a rewrite that fuses nested BlockDiag Ops into a single one. I’ll update this PR with code once the setup is complete and I have an initial version of the rewrite.

Related Issue

  • [x] Closes #1593

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1671.org.readthedocs.build/en/1671/

eby0303 • Oct 16 '25 14:10

Hey @jessegrabowski! As suggested in the issue discussion, I’ve opened this draft PR to start working on the BlockDiag rewrite. I’m setting up my local environment now. Please feel free to share any tips or guidance on where to begin. I’ll update this PR as I make progress.

eby0303 • Oct 16 '25 14:10

I'd suggest you work in a test-driven way. Add a test to tests/tensor/rewriting/test_linalg.py with a simple nested blockwise, count the number of BlockDiag ops, and assert that there is only 1. Confirm that this test fails. Then add a rewrite to tensor/rewriting/linalg.py that looks for a blockwise with a blockwise inside and, if it finds one, merges them.

For an example of how to count ops in a graph for the test, look here (BUT the whole class is overkill for your case; just take the pieces you need from it and write an inline version).

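For a good rewrite template to get you started, I think this one is pretty readable. You will need to 1) check that the node's op is a BlockDiag op, 2) check that at least one of the inputs to the BlockDiag is itself the output of a BlockDiag, 3) pull out the inputs from the inner BlockDiag, 4) make a new BlockDiag whose n_inputs is the outer count minus one plus the inner count (n_inputs = old_n_inputs + 1 in the simple case of one 2-input BlockDiag nested in another) and return it, passing in all the flattened inputs (3 in that simple case).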
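The four steps above, together with the inline op-counting test from the first paragraph, can be sketched as a node-level rewrite on a toy graph. This is a minimal sketch that runs without PyTensor: `Var`, `Node`, `BlockDiag`, `block_diag`, `count_block_diag`, `fuse_nested_block_diag` and `rewrite_output` are hypothetical stand-ins, not PyTensor's API. In PyTensor itself the op is `BlockDiagonal` (the real graph may also wrap it in a `Blockwise`, which the actual rewrite would need to look through), and the rewrite would be registered in tensor/rewriting/linalg.py. The splice in step 3 is written generally, so the fused `n_inputs` is simply the length of the flattened input list.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(frozen=True)
class BlockDiag:
    """Hypothetical stand-in for PyTensor's BlockDiagonal Op."""
    n_inputs: int


@dataclass(eq=False)
class Var:
    """Graph variable; `owner` is the node that produced it (None for graph inputs)."""
    name: str
    owner: Optional["Node"] = field(default=None, repr=False)


@dataclass(eq=False)
class Node:
    """One op applied to input variables, producing a single output variable."""
    op: BlockDiag
    inputs: List[Var]
    output: Var = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self.output = Var("out", owner=self)


def block_diag(*mats: Var) -> Var:
    """Build a BlockDiag node over `mats` and return its output."""
    return Node(BlockDiag(n_inputs=len(mats)), list(mats)).output


def is_block_diag_output(var: Var) -> bool:
    return var.owner is not None and isinstance(var.owner.op, BlockDiag)


def count_block_diag(out: Var) -> int:
    """Inline op counter for the test: BlockDiag nodes reachable from `out`."""
    seen, stack = set(), [out]
    while stack:
        node = stack.pop().owner
        if node is not None and node not in seen:
            seen.add(node)
            stack.extend(node.inputs)
    return sum(isinstance(n.op, BlockDiag) for n in seen)


def fuse_nested_block_diag(node: Node) -> Optional[Var]:
    """Return a fused replacement for `node`'s output, or None if nothing to do."""
    # 1) Only BlockDiag nodes are candidates.
    if not isinstance(node.op, BlockDiag):
        return None
    # 2) At least one input must itself be produced by a BlockDiag.
    if not any(is_block_diag_output(inp) for inp in node.inputs):
        return None
    # 3) Splice each inner BlockDiag's inputs in place of its output, keeping order.
    new_inputs: List[Var] = []
    for inp in node.inputs:
        new_inputs.extend(inp.owner.inputs if is_block_diag_output(inp) else [inp])
    # 4) A single BlockDiag whose n_inputs is the length of the flattened list.
    return block_diag(*new_inputs)


def rewrite_output(out: Var) -> Var:
    """Reapply the rewrite at the output until it returns None (a fixed point)."""
    while out.owner is not None:
        new_out = fuse_nested_block_diag(out.owner)
        if new_out is None:
            break
        out = new_out
    return out


# Test-first: the nested graph has 2 BlockDiag nodes, so "assert only 1" fails...
a, b, c = Var("a"), Var("b"), Var("c")
nested = block_diag(block_diag(a, b), c)
assert count_block_diag(nested) == 2

# ...and passes once the rewrite is applied.
fused = rewrite_output(nested)
assert count_block_diag(fused) == 1
assert fused.owner.op.n_inputs == 3  # old n_inputs (2) + 1
assert [v.name for v in fused.owner.inputs] == ["a", "b", "c"]
```

Keeping the spliced inputs in their original order is what makes the fusion valid, since block_diag(block_diag(a, b), c) and block_diag(a, block_diag(b, c)) both equal block_diag(a, b, c). Deeper nesting is handled by reapplying the rewrite until it returns None, as a graph rewriter iterating to a fixed point would.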

jessegrabowski • Oct 16 '25 14:10

@jessegrabowski! I added a rewrite to fuse nested BlockDiagonal ops and updated test_linalg.py with a test for nested BlockDiagonal fusion.

eby0303 • Oct 16 '25 16:10

@jessegrabowski Added a rewrite to fuse nested BlockDiagonal ops into a single fused instance and included tests to verify fusion behavior, n_inputs, and output shape.
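The value and output-shape checks mentioned here rest on the identity the rewrite relies on: a block diagonal of block diagonals equals one block diagonal of the flattened blocks, whose shape is the sum of the blocks' rows by the sum of their columns. A small NumPy sketch of that identity follows; `np_block_diag` is a hypothetical helper written for illustration (SciPy's `scipy.linalg.block_diag` computes the same thing and could serve as the reference in a real test).

```python
import numpy as np


def np_block_diag(*mats: np.ndarray) -> np.ndarray:
    """Place 2-D matrices along the diagonal of an otherwise zero matrix."""
    rows = sum(m.shape[0] for m in mats)
    cols = sum(m.shape[1] for m in mats)
    out = np.zeros((rows, cols), dtype=np.result_type(*mats))
    r = c = 0
    for m in mats:
        out[r:r + m.shape[0], c:c + m.shape[1]] = m
        r += m.shape[0]
        c += m.shape[1]
    return out


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 2))
b = rng.normal(size=(3, 1))  # blocks need not be square
c = rng.normal(size=(1, 4))

nested = np_block_diag(np_block_diag(a, b), c)  # what the unfused graph computes
fused = np_block_diag(a, b, c)                  # what the fused op computes

# Output shape: (2 + 3 + 1, 2 + 1 + 4) == (6, 7); values match exactly.
assert fused.shape == (6, 7)
assert np.array_equal(nested, fused)
```

Because the fused op only copies blocks into place, the comparison can be exact rather than within a tolerance.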

eby0303 • Nov 04 '25 17:11