Sparsity support in pytensor
Description
I'm investigating implementing ALS in pytensor, which is usually implemented with sparsity constructs (see implicit for reference). I had a quick look around and found this older thread where someone asked for sparsity support. @jessegrabowski gave a first-pass answer, but mentioned the support is subpar. Opening this to track any enhancements we could bring.
Is this specifically about implementing solve with sparse inputs, or do you have other Ops in mind?
The original issue seems to be just solve, I imagine that's good enough to start
So for solve it's easy enough to wrap scipy.sparse.linalg.spsolve for the C backend. We need gradients, but I found a JAX implementation here that we can copy, so that should be easy enough.
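For concreteness, here is a rough sketch of what such a wrapper could look like. The class name `SparseSolve` is hypothetical, dtype/shape handling is simplified, and I'm assuming `pytensor.sparse.as_sparse_variable` for typing the sparse input; the gradient is omitted, but would follow the usual implicit-function rule (solve with `A.T` against the output gradient to get `b̄`, then restrict `-b̄ xᵀ` to A's sparsity pattern), as in the JAX implementation linked above.

```python
import numpy as np
from scipy.sparse.linalg import spsolve

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.sparse import as_sparse_variable


class SparseSolve(Op):
    """Hypothetical Op solving A x = b where A is a scipy.sparse matrix."""

    def make_node(self, A, b):
        A = as_sparse_variable(A)
        b = pt.as_tensor_variable(b)
        # dtype promotion and shape checks elided for brevity
        return Apply(self, [A, b], [b.type()])

    def perform(self, node, inputs, outputs):
        A, b = inputs  # A arrives as a scipy.sparse matrix at runtime
        outputs[0][0] = np.asarray(spsolve(A, b))
```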
For the numba backend we need to write our own overrides, as I did for solve_discrete_are, for example. spsolve calls out to SuperLU, which I think we can hook into the same way we do for the other cLAPACK functions. There's also an optional package, UMFPACK, which scipy appears to prefer when it's available, but that would be a new dependency.
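Until a direct SuperLU (or UMFPACK) binding exists, one stopgap for the numba backend would be to drop into object mode and call scipy there. A rough sketch (the function name and the CSR-component calling convention are made up for illustration):

```python
import numba
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve


@numba.njit
def spsolve_csr_objmode(data, indices, indptr, n, b):
    # Object-mode fallback: rebuild the CSR matrix and call scipy's spsolve.
    # A real override would instead bind to SuperLU the way the existing
    # cLAPACK-backed Ops do.
    with numba.objmode(x="float64[:]"):
        A = sp.csr_matrix((data, indices, indptr), shape=(n, n))
        x = spsolve(A, b)
    return x
```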
Finally, for the Torch backend I honestly have no idea. It looks like torch has sparse support as well as an spsolve implementation, so it might be straightforward?
Great, thanks for the plan!
Out of curiosity, what could we do? I'm not aware of what general sparsity support would look like beyond making sure things can use CSR and CSC tensors, which I think pytensor already has a sparse module for.
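For context, pytensor does already ship a pytensor.sparse module (inherited from Theano) with CSR/CSC variable types and a handful of Ops. Roughly, assuming the Theano-era API is unchanged:

```python
import numpy as np
import scipy.sparse as sp

import pytensor
import pytensor.sparse as ps
import pytensor.tensor as pt

# Symbolic CSR input and a structured dot against a dense matrix
x = ps.csr_matrix(name="x", dtype="float64")
y = pt.matrix("y")
z = ps.structured_dot(x, y)

f = pytensor.function([x, y], z)
A = sp.random(5, 5, density=0.2, format="csr")
print(f(A, np.ones((5, 3))))
```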
Some thoughts:
- Agreement/harmonization on the status of the `SparseMatrix` primitive. I thought there was an issue or discussion about this, but for the life of me I can't find it @ricardoV94 .
- Support for as full a suite of linear algebra operations on sparse matrices as possible. This page has the list of what I would consider to be possible. Interesting ones would be `spsolve`, `spsolve_triangular`, `eigs`, `svds`, and `kron` (a quick scipy sketch of these follows this list).
- Better rewrites/detection for when sparse matrices are involved in Ops, for example #321
- Support for batch dimensions, see #839
- Support for symbolic sparse jacobians from `pytensor.gradient.jacobian`, see https://github.com/mfschubert/sparsejac
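As a quick illustration of the scipy.sparse / scipy.sparse.linalg surface the second bullet refers to (target semantics only, not pytensor code):

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve, spsolve_triangular, eigs, svds

A = sp.random(20, 20, density=0.1, format="csr") + 20 * sp.identity(20, format="csr")
b = np.ones(20)

x = spsolve(A, b)                              # general sparse solve
L = sp.tril(A, format="csr")
y = spsolve_triangular(L, b, lower=True)       # triangular solve
vals, vecs = eigs(A, k=3)                      # a few eigenpairs
u, s, vt = svds(A, k=3)                        # truncated SVD
K = sp.kron(A, sp.identity(2))                 # sparse Kronecker product
```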
> Agreement/harmonization on the status of the `SparseMatrix` primitive. I thought there was an issue or discussion about this, but for the life of me I can't find it @ricardoV94 .
I think this was discussed in the design-notes: https://github.com/pymc-devs/design-notes/blob/main/PyTensor%20design%20meeting%20(July%207%2C%202023).md#type-compatibility-across-backends
Does torch work with scipy-like CSC/CSR matrices? If so, it should be easy to implement. JAX doesn't, so the only thing we do for now is support operations on sparse constants -> dense outputs, which we convert to the JAX-compatible ones.
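Roughly what that looks like on the JAX side — a sketch, not the actual pytensor dispatch code, assuming `BCOO.from_scipy_sparse` from recent JAX versions as the conversion to the JAX-compatible format:

```python
import numpy as np
import scipy.sparse as sp
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

# A scipy sparse constant gets converted to JAX's BCOO format...
A_scipy = sp.random(4, 4, density=0.3, format="csr")
A_bcoo = jsparse.BCOO.from_scipy_sparse(A_scipy)

# ...and operations against it produce dense outputs.
b = jnp.ones(4)
x = A_bcoo @ b
print(np.asarray(x))
```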
This PR touches on this as well, I should perhaps revive it: #278
For torch, it does have support for CSC and CSR formats. It also has a more generic one (COO) and some specialized formats. They do have some weird characteristics: you can't use them at all if you don't call .coalesce(), but maybe that's okay.
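For reference, the torch layouts mentioned above (CSR/CSC via to_sparse_csr/to_sparse_csc, and the generic COO layout that needs coalescing before its indices/values can be read):

```python
import torch

dense = torch.tensor([[1.0, 0.0], [0.0, 2.0]])

A_csr = dense.to_sparse_csr()     # scipy-like CSR layout
A_csc = dense.to_sparse_csc()     # scipy-like CSC layout
A_coo = dense.to_sparse()         # generic COO layout

# COO tensors must be coalesced before indices()/values() are usable
A_coo = A_coo.coalesce()
print(A_csr.crow_indices(), A_csr.col_indices(), A_csr.values())
print(A_coo.indices(), A_coo.values())
```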
The upcoming implementation of INLA in pymc-extras has ramped up the value-add of improved sparsity support. I think @jessegrabowski mentioned adding Ops for scipy.sparse wouldn't be too heavy of a lift.
Reposting @theorashid's notes from https://github.com/pymc-devs/pymc-extras/issues/340:
INLA can work without it, but this is what will make it very quick and scalable and get it nearer to R-INLA performance. This would lie in https://github.com/pymc-devs/pytensor/tree/main/pytensor/sparse. There is a jax implementation of all the parts we need.
- [ ] Implement sparse matrix in pytensor with ability to:
- add to the diagonal
- multiply by a vector
- solve a linear system
- compute a log determinant
- [ ] Implement a MVN distribution with a sparse precision matrix
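For concreteness, here is what the four operations in the checklist above look like with plain scipy.sparse on a precision matrix Q — a sketch of the target semantics only, not pytensor Ops. For the log-determinant a sparse Cholesky (e.g. CHOLMOD via scikit-sparse) would be the more standard choice; the LU-based version below works because Q is SPD:

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import splu, spsolve

n = 50
Q = sp.random(n, n, density=0.05, format="csc")
Q = (Q @ Q.T + n * sp.identity(n, format="csc")).tocsc()  # symmetric positive definite

# 1. add to the diagonal
Q_shift = Q + 2.0 * sp.identity(n, format="csc")

# 2. multiply by a vector
v = np.ones(n)
Qv = Q @ v

# 3. solve a linear system
x = spsolve(Q, v)

# 4. log determinant via sparse LU (L has unit diagonal; Q is SPD so det > 0)
lu = splu(Q)
logdet = np.sum(np.log(np.abs(lu.U.diagonal())))
```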