Enzyme-JAX

Structured Matrix / Tensors Detection

Open • avik-pal opened this issue 6 months ago • 1 comment

Listing out the structured tensors that we can detect (ideally, put these in a separate .cc file rather than alongside the HLO patterns):

  • [ ] UpperTriangular
  • [ ] UnitUpperTriangular
  • [ ] LowerTriangular
  • [ ] UnitLowerTriangular
  • [ ] Banded
  • [ ] Diagonal
  • [ ] Bidiagonal
  • [ ] Tridiagonal
  • [ ] Symmetric
  • [ ] Hermitian
  • [ ] BlockSparse
    • (usually BlockDiagonal; ideally, the block-sparsity pattern is known at compile time)
  • [ ] OneHot

(Feel free to add more here)
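For reference, here is what each structure in the list means, written as a minimal sketch of runtime predicates on concrete JAX arrays. The function names (is_upper_triangular, is_banded, and so on) are hypothetical. These are value checks, not the compile-time detection over HLO ops that the issue proposes. They assume exact zeros, bandwidths given as (lower, upper) as in LAPACK-style banded storage, block sizes known ahead of time for BlockDiagonal, and (for OneHot) one-hot rows, matching how jax.nn.one_hot encodes along the last axis.

```python
import jax.numpy as jnp

# Hypothetical helpers: runtime checks on concrete (non-traced) arrays only.
# Exact comparisons are used on purpose: structure means exact zeros/ones.

def is_upper_triangular(a):
    return bool(jnp.all(jnp.tril(a, k=-1) == 0))

def is_unit_upper_triangular(a):
    return is_upper_triangular(a) and bool(jnp.all(jnp.diagonal(a) == 1))

def is_lower_triangular(a):
    return bool(jnp.all(jnp.triu(a, k=1) == 0))

def is_unit_lower_triangular(a):
    return is_lower_triangular(a) and bool(jnp.all(jnp.diagonal(a) == 1))

def is_banded(a, lower, upper):
    # Nonzeros allowed only where -lower <= j - i <= upper.
    i = jnp.arange(a.shape[0])[:, None]
    j = jnp.arange(a.shape[1])[None, :]
    outside = (j - i > upper) | (i - j > lower)
    return bool(jnp.all(jnp.where(outside, a, 0) == 0))

def is_diagonal(a):
    return is_banded(a, 0, 0)

def is_bidiagonal(a):
    # Either upper bidiagonal (0, 1) or lower bidiagonal (1, 0).
    return is_banded(a, 0, 1) or is_banded(a, 1, 0)

def is_tridiagonal(a):
    return is_banded(a, 1, 1)

def is_symmetric(a):
    return a.shape[0] == a.shape[1] and bool(jnp.all(a == a.T))

def is_hermitian(a):
    # Also forces a real diagonal.
    return a.shape[0] == a.shape[1] and bool(jnp.all(a == jnp.conj(a).T))

def is_block_diagonal(a, block_sizes):
    # block_sizes: sizes of the square diagonal blocks, assumed known up front.
    mask = jnp.zeros(a.shape, dtype=bool)
    start = 0
    for size in block_sizes:
        mask = mask.at[start:start + size, start:start + size].set(True)
        start += size
    covers = start == a.shape[0] == a.shape[1]
    return covers and bool(jnp.all(jnp.where(mask, 0, a) == 0))

def is_one_hot(a):
    # Assumption: each row holds exactly one 1 and zeros elsewhere.
    is_binary = jnp.all((a == 0) | (a == 1))
    return bool(is_binary & jnp.all(jnp.sum(a == 1, axis=-1) == 1))
```

Because they call bool() on array values, these would fail on tracers inside jax.jit; a compiler pass would instead recognize the ops that produce such values (for example a mask applied via a select against an iota comparison) without looking at the data.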

Knowing these structures would let us lower to faster specialized routines for factorizations, matmuls, etc.
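To make the payoff concrete, here is a small sketch at the JAX user level (not how Enzyme-JAX itself would lower anything): an upper-triangular system can be solved by back substitution in O(n^2) via jax.scipy.linalg.solve_triangular instead of a general LU-based jnp.linalg.solve in O(n^3), and a product with a diagonal matrix reduces to row scaling. The matrix sizes and random inputs are arbitrary choices for illustration; the idea in the issue is for the compiler to pick such routines automatically once the structure is detected.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Upper-triangular system; adding 4*I keeps the diagonal away from zero.
a = jnp.triu(jax.random.normal(jax.random.PRNGKey(0), (4, 4))) + 4.0 * jnp.eye(4)
b = jnp.arange(4.0)

x_general = jnp.linalg.solve(a, b)           # generic path: LU factorization, O(n^3)
x_tri = solve_triangular(a, b, lower=False)  # structure-aware: back substitution, O(n^2)

# Diagonal matmul: a dense O(n^2 k) product vs. an O(n k) row scaling.
d = jnp.array([1.0, 2.0, 3.0, 4.0])
m = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
dense = jnp.diag(d) @ m
scaled = d[:, None] * m
```

Both pairs agree up to floating-point rounding; the structured versions simply skip work on entries known to be zero.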

avik-pal • Jun 20 '25 01:06

Mmm, what's your idea for implementing this? Adding an MLIR type per structured tensor type? Or maybe adding an attribute that marks ops supporting a specific structure?

mofeing • Jun 20 '25 09:06