Enzyme-JAX
Structured Matrix / Tensor Detection
Listing out the structured tensors that we can detect (ideally these would live in a separate .cc file rather than alongside the HLO patterns):
- [ ] UpperTriangular
- [ ] UnitUpperTriangular
- [ ] LowerTriangular
- [ ] UnitLowerTriangular
- [ ] Banded
- [ ] Diagonal
- [ ] Bidiagonal
- [ ] Tridiagonal
- [ ] Symmetric
- [ ] Hermitian
- [ ] BlockSparse
  - Usually BlockDiagonal; ideally, the block sparsity pattern is known at compile time.
- [ ] OneHot
(Feel free to add more here)
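To pin down what each checklist entry means, here is a minimal sketch of reference predicates on concrete square matrices (lists of rows). This is not the compile-time analysis itself, which would have to reason about IR rather than values; it is only a runtime oracle that could, for example, be used to test such an analysis. The names `bandwidths` and `classify` are hypothetical, "OneHot" is assumed to mean "each row has exactly one nonzero, equal to 1", and BlockSparse is not covered.

```python
# Hypothetical reference predicates for structures in the checklist.
# Operates on concrete square matrices given as lists of rows.

def bandwidths(a):
    """Return (lower, upper): the largest distance of a nonzero entry
    below / above the main diagonal."""
    lower = upper = 0
    for i, row in enumerate(a):
        for j, x in enumerate(row):
            if x != 0:
                lower = max(lower, i - j)
                upper = max(upper, j - i)
    return lower, upper


def classify(a):
    """Return the set of checklist structures that `a` satisfies.

    Banded is reported via bandwidths() rather than as a tag, since every
    matrix is banded for some bandwidth. BlockSparse is not covered.
    """
    n = len(a)
    lo, up = bandwidths(a)
    unit_diag = all(a[i][i] == 1 for i in range(n))
    tags = set()
    if lo == 0:
        tags.add("UpperTriangular")
        if unit_diag:
            tags.add("UnitUpperTriangular")
    if up == 0:
        tags.add("LowerTriangular")
        if unit_diag:
            tags.add("UnitLowerTriangular")
    if lo == 0 and up == 0:
        tags.add("Diagonal")
    if (lo == 0 and up <= 1) or (up == 0 and lo <= 1):
        tags.add("Bidiagonal")
    if lo <= 1 and up <= 1:
        tags.add("Tridiagonal")
    if all(a[i][j] == a[j][i] for i in range(n) for j in range(n)):
        tags.add("Symmetric")
    # .conjugate() exists on int, float and complex, so real input works too.
    if all(a[i][j] == a[j][i].conjugate() for i in range(n) for j in range(n)):
        tags.add("Hermitian")
    # Assumed meaning: every row has exactly one nonzero, and it equals 1.
    if all(sum(x != 0 for x in row) == 1 and 1 in row for row in a):
        tags.add("OneHot")
    return tags
```

For instance, `classify([[1, 2], [0, 1]])` reports UpperTriangular, UnitUpperTriangular, Bidiagonal and Tridiagonal, since the only off-diagonal nonzero sits one step above the diagonal.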
Knowing these structures lets us lower to faster factorization routines, matmuls, etc. (e.g. a triangular solve instead of a general LU-based solve).
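As a rough illustration of the payoff, the sketch below dispatches a linear solve on a structure tag: triangular systems use O(n^2) substitution, diagonal systems an O(n) division, and everything else falls back to O(n^3) Gaussian elimination with partial pivoting. It is plain Python standing in for what would be an MLIR rewrite; `solve` and its `structure` parameter are hypothetical, and the tag strings are simply the names from the checklist.

```python
# Hypothetical structure-aware dispatch for solving a @ x = b.

def back_substitute(u, b):
    """Solve an upper-triangular system in O(n^2)."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        s = b[i] - sum(u[i][j] * x[j] for j in range(i + 1, n))
        x[i] = s / u[i][i]
    return x


def forward_substitute(l, b):
    """Solve a lower-triangular system in O(n^2)."""
    n = len(b)
    x = [0.0] * n
    for i in range(n):
        s = b[i] - sum(l[i][j] * x[j] for j in range(i))
        x[i] = s / l[i][i]
    return x


def gaussian_solve(a, b):
    """General O(n^3) fallback: elimination with partial pivoting."""
    n = len(b)
    m = [[float(v) for v in row] + [float(bi)] for row, bi in zip(a, b)]
    for k in range(n):
        p = max(range(k, n), key=lambda r: abs(m[r][k]))  # pivot row
        m[k], m[p] = m[p], m[k]
        for r in range(k + 1, n):
            f = m[r][k] / m[k][k]
            for c in range(k, n + 1):
                m[r][c] -= f * m[k][c]
    return back_substitute([row[:n] for row in m], [row[n] for row in m])


def solve(a, b, structure=None):
    """Pick the cheapest routine the known structure allows."""
    if structure == "Diagonal":
        return [bi / a[i][i] for i, bi in enumerate(b)]
    if structure == "UpperTriangular":
        return back_substitute(a, b)
    if structure == "LowerTriangular":
        return forward_substitute(a, b)
    return gaussian_solve(a, b)
```

The tagged and untagged paths agree on the result; the tag only removes the elimination work that the structure makes unnecessary.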
Mmm, what's your idea for how to implement this? Like adding an MLIR type per structured tensor type? Or maybe adding an attribute to mark ops that support a specific structure?
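One way to picture the attribute option, as opposed to new types, is a set of structure tags carried by each value plus a transfer rule per op that says which tags survive. The sketch below is a toy model under that assumption: the tag names come from the checklist, while `close`, `propagate` and the op names `transpose`/`matmul`/`add` are illustrative only and are not Enzyme-JAX or StableHLO APIs. Unknown ops conservatively drop all tags.

```python
# Hypothetical structure tags on values, plus per-op transfer rules.

IMPLIES = {
    "Diagonal": {"UpperTriangular", "LowerTriangular", "Symmetric"},
    "UnitUpperTriangular": {"UpperTriangular"},
    "UnitLowerTriangular": {"LowerTriangular"},
}

# Tags preserved by matmul / add when *both* operands carry them.
MATMUL_KEEPS = {"UpperTriangular", "LowerTriangular", "UnitUpperTriangular",
                "UnitLowerTriangular", "Diagonal"}
ADD_KEEPS = {"UpperTriangular", "LowerTriangular", "Diagonal", "Symmetric"}

# How transpose maps each tag; tags not listed here are dropped.
TRANSPOSE_MAP = {
    "UpperTriangular": "LowerTriangular",
    "LowerTriangular": "UpperTriangular",
    "UnitUpperTriangular": "UnitLowerTriangular",
    "UnitLowerTriangular": "UnitUpperTriangular",
    "Diagonal": "Diagonal",
    "Symmetric": "Symmetric",
}


def close(tags):
    """Add every implied tag (one level suffices for IMPLIES above)."""
    out = set(tags)
    for t in list(out):
        out |= IMPLIES.get(t, set())
    return frozenset(out)


def propagate(op, *operands):
    """Infer the result's tags from the operands' tags."""
    operands = [close(t) for t in operands]
    if op == "transpose":
        (a,) = operands
        return close(TRANSPOSE_MAP[t] for t in a if t in TRANSPOSE_MAP)
    if op == "matmul":
        a, b = operands
        return close(a & b & MATMUL_KEEPS)
    if op == "add":
        a, b = operands
        return close(a & b & ADD_KEEPS)
    return frozenset()  # unknown op: assume nothing
```

For example, a Diagonal times a UnitUpperTriangular matrix is only known to be UpperTriangular, and the sum of two unit upper-triangular matrices loses the "unit" tag because its diagonal becomes 2. Keeping tags as plain sets makes "both operands agree" an intersection, which is easy to extend when new structures are added to the list.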