
Linear solvers for sparse matrices

Open: arpastrana opened this issue 1 year ago • 4 comments

This is just a question: are there any plans to support JAX-compatible linear solvers for sparse matrices? I am thinking of sparse linear systems of the type $Ax=b$ where $A$ is sparse.

arpastrana, Jun 07 '23 22:06

Right! So this is doable like so:

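```python
import equinox as eqx
import lineax as lx
import jax.experimental.sparse as js
import jax.numpy as jnp

# Reuse MatrixLinearOperator, but compute matrix-vector products through
# JAX's experimental sparse support so the BCOO matrix is never densified.
class SparseMatrixLinearOperator(lx.MatrixLinearOperator):
    def mv(self, vector):
        return js.sparsify(lambda m, v: m @ v)(self.matrix, vector)

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
x = js.BCOO.fromdense(x)  # convert the dense matrix to sparse BCOO format
op = SparseMatrixLinearOperator(x)
vec = jnp.array([1., 2.])

@eqx.filter_jit
def f(op, vec):
    # Solve with GMRES, an iterative solver that only calls op.mv.
    sol = lx.linear_solve(op, vec, solver=lx.GMRES(rtol=1e-5, atol=1e-5))
    return sol

print(f(op, vec).value)
```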

This is basically just (a) overriding the `mv` (matrix-vector product) method so that it works with a sparse matrix, and (b) using an iterative solver (here GMRES), which only needs matrix-vector products, rather than a direct solver such as LU (which requires materialising the matrix).

This isn't built into Lineax by default as JAX's sparse support is still pretty experimental, and we want to see how that's going to pan out before we design an API around it.

patrick-kidger, Jun 07 '23 23:06

Hello, @patrick-kidger.

I am a new user of Lineax and would like to ask: does Lineax now support automatic differentiation (AD) with sparse matrices?

DoTulip, Apr 07 '24 04:04

Only to whatever extent JAX itself does.

patrick-kidger, Apr 07 '24 07:04

Thank you very much for your reply!

DoTulip, Apr 07 '24 09:04