
Sparsity support in pytensor

Ch0ronomato opened this issue 1 year ago

Description

I'm investigating implementing ALS in pytensor, which is usually done with sparsity constructs (see implicit for reference). I quickly looked around and saw this older thread where someone asked for sparsity support. @jessegrabowski gave a first-pass answer, but mentioned the support is subpar. Opening this to track any enhancements we could bring.

Ch0ronomato, Dec 15 '24 17:12

Is this specifically about implementing solve with sparse inputs, or do you have other Ops in mind?

jessegrabowski, Dec 15 '24 18:12

The original issue seems to be just solve; I imagine that's good enough to start.

Ch0ronomato, Dec 15 '24 20:12

So for solve it's easy enough to wrap scipy.sparse.linalg.spsolve for the C backend. We need gradients, but I found a JAX implementation here that we can copy, so that should be easy enough.
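To make that concrete, here is a minimal sketch of what an Op wrapping scipy's spsolve could look like, assuming a sparse A and a dense vector b. The class and instance names are placeholders, and it deliberately skips the gradient, infer_shape, and whatever helpers may already exist under pytensor.sparse:

```python
import numpy as np
from scipy.sparse.linalg import spsolve

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.sparse import as_sparse_variable


class SpSolve(Op):
    """Solve A @ x = b for sparse A and dense b via scipy.sparse.linalg.spsolve."""

    __props__ = ()

    def make_node(self, A, b):
        A = as_sparse_variable(A)
        b = pt.as_tensor_variable(b)
        out_dtype = np.promote_types(A.dtype, b.dtype).name
        return Apply(self, [A, b], [pt.vector(dtype=out_dtype)])

    def perform(self, node, inputs, output_storage):
        A_val, b_val = inputs  # A arrives as a scipy sparse matrix, b as an ndarray
        output_storage[0][0] = np.asarray(spsolve(A_val.tocsc(), b_val))


sp_solve = SpSolve()
```

The gradient would then follow the implicit-function derivation from the JAX implementation mentioned above.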

For the numba backend we need to write our own overrides, as I did for solve_discrete_are, for example. spsolve calls out to SuperLU, which I think we can hook into the same way we do for the other LAPACK functions. There's also an optional package, umfpack, which scipy appears to prefer when it's available, but that would be a new dependency.
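One wrinkle is that numba has no understanding of scipy's sparse matrix classes, so any override ultimately has to work on the raw data/indices/indptr arrays (or bind SuperLU directly). As a toy illustration of what such a kernel looks like, here is a forward substitution for a lower-triangular CSC matrix, assuming sorted indices with the diagonal stored first in each column:

```python
import numpy as np
from numba import njit


@njit(cache=True)
def spsolve_lower_csc(data, indices, indptr, b):
    """Forward substitution for a lower-triangular matrix stored as CSC arrays."""
    x = b.astype(np.float64)
    n = b.shape[0]
    for j in range(n):
        start, end = indptr[j], indptr[j + 1]
        x[j] /= data[start]  # first stored entry in column j is the diagonal
        for k in range(start + 1, end):
            x[indices[k]] -= data[k] * x[j]
    return x
```

This can be called with the .data, .indices, and .indptr attributes of a scipy csc_matrix; a real spsolve override would instead hook into SuperLU as suggested above.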

Finally, for the Torch backend I honestly have no idea. It looks like torch has sparse support as well as an spsolve implementation, so it might be straightforward?

jessegrabowski, Dec 15 '24 20:12

Great, thanks for the plan!

Out of curiosity, what else could we do? I'm not aware of what general sparsity support would look like beyond making sure things can use CSR and CSC tensors, which I think pytensor already supports.

Ch0ronomato, Dec 15 '24 22:12

Some thoughts:

  1. Agreement/harmonization on the status of the SparseMatrix primitive. I thought there was an issue or discussion about this, but for the life of me I can't find it @ricardoV94 .
  2. Support for as full a suite of linear algebra operations on sparse matrices as possible. This page has the list of what I would consider to be possible. Interesting ones would be spsolve, spsolve_triangular, eigs, svds, and kron (see the scipy sketch after this list).
  3. Better rewrites/detection for when sparse matrices are involved in Ops, for example #321
  4. Support for batch dimensions, see #839
  5. Support for symbolic sparse jacobians from pytensor.gradient.jacobian, see https://github.com/mfschubert/sparsejac
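For reference, everything named in item 2 is already available in scipy, which is presumably what these Ops would delegate to on the C backend. A quick sketch with an arbitrary test matrix:

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

# Arbitrary sparse test matrix, shifted by the identity to keep it well conditioned.
A = (sp.random(50, 50, density=0.05, random_state=0) + sp.eye(50)).tocsc()
b = np.ones(50)

x = spla.spsolve(A, b)                               # sparse direct solve
xt = spla.spsolve_triangular(sp.tril(A).tocsr(), b)  # lower-triangular solve (wants CSR)
w, vecs = spla.eigs(A, k=3)                          # a few eigenpairs via ARPACK
u, s, vt = spla.svds(A, k=3)                         # truncated SVD
K = sp.kron(A, sp.eye(2), format="csr")              # Kronecker product stays sparse
```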

jessegrabowski, Dec 16 '24 01:12

Agreement/harmonization on the status of the SparseMatrix primitive. I thought there was an issue or discussion about this, but for the life of me I can't find it @ricardoV94 .

I think this was discussed in the design-notes: https://github.com/pymc-devs/design-notes/blob/main/PyTensor%20design%20meeting%20(July%207%2C%202023).md#type-compatibility-across-backends

Does torch work with scipy-like CSC/CSR matrices? If so, it should be easy to implement. JAX doesn't, so for now we only support operations on sparse constants (converted to the JAX-compatible formats) that produce dense outputs.
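To make the JAX side concrete, the conversion of sparse constants presumably goes through something like jax.experimental.sparse's BCOO format; a small standalone sketch, outside of PyTensor's actual linker code:

```python
import scipy.sparse as sp
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

# A scipy sparse constant converted to JAX's batched-COO (BCOO) representation.
A_scipy = sp.random(4, 4, density=0.5, format="csr", random_state=0)
A_bcoo = jsparse.BCOO.from_scipy_sparse(A_scipy)

y = A_bcoo @ jnp.ones(4)  # sparse-dense matvec; the result is a dense JAX array
```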

This PR touches on this as well, I should perhaps revive it: #278

ricardoV94, Jan 05 '25 11:01

For torch, it does have support for CSC and CSR formats. It also has a more generic one (COO) and some specialized formats. They do have some weird characteristics: you can't use them at all until you call .coalesce(), but maybe that's okay.
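A small illustration of the coalesce point and the CSR conversion, against the public torch API (the matmul at the end assumes a reasonably recent torch with sparse-CSR matmul support):

```python
import torch

# COO construction with a duplicate entry at position (0, 0).
indices = torch.tensor([[0, 0, 1], [0, 0, 2]])
values = torch.tensor([1.0, 2.0, 3.0])
A = torch.sparse_coo_tensor(indices, values, size=(2, 3))

A = A.coalesce()              # merges duplicates: the (0, 0) entry becomes 3.0
A_csr = A.to_sparse_csr()     # CSR layout, the closest analogue to scipy's csr_matrix
y = A_csr @ torch.ones(3, 1)  # sparse-dense matmul, returns a dense tensor
```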

Ch0ronomato, Jan 16 '25 04:01

The upcoming implementation of INLA in pymc-extras has ramped up the value-add of improved sparsity support. I think @jessegrabowski mentioned adding Ops for scipy.sparse wouldn't be too heavy of a lift.

Reposting @theorashid's notes from https://github.com/pymc-devs/pymc-extras/issues/340:

INLA can work without it, but this is what will make it very quick and scalable and get it nearer to R-INLA performance. This would lie in https://github.com/pymc-devs/pytensor/tree/main/pytensor/sparse. There is a jax implementation of all the parts we need.

  • [ ] Implement a sparse matrix in pytensor with the ability to (see the scipy sketch below):
      • add to the diagonal
      • multiply by a vector
      • solve a linear system
      • compute a log determinant
  • [ ] Implement an MVN distribution with a sparse precision matrix
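For context, the four bullets under the first item map directly onto scipy's sparse toolbox; a minimal sketch, using an arbitrary SPD matrix as a stand-in for the precision matrix (for real INLA work a sparse Cholesky, e.g. scikit-sparse, would be preferable to the LU-based log determinant shown here):

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import splu, spsolve

n = 100
# Arbitrary sparse SPD matrix standing in for a precision matrix Q.
Q = sp.random(n, n, density=0.02, random_state=0)
Q = (Q @ Q.T + sp.eye(n)).tocsc()

Q_shifted = Q + 2.0 * sp.eye(n)  # add to the diagonal
v = np.ones(n)
w = Q @ v                        # multiply by a vector
x = spsolve(Q, v)                # solve a linear system

# Log determinant via sparse LU: L has a unit diagonal, so log|det Q| is the
# sum of log|diag(U)|, and det Q > 0 because Q is SPD.
lu = splu(Q)
logdet = np.sum(np.log(np.abs(lu.U.diagonal())))
```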

ColtAllen, Aug 30 '25 14:08