lineax
Linear solvers for sparse matrices
This is just a question: are there any plans to support JAX-compatible linear solvers for sparse matrices? I am thinking of sparse linear systems of the type $Ax=b$ where $A$ is sparse.
Right! So this is doable like so:
```python
import equinox as eqx
import lineax as lx
import jax.experimental.sparse as js
import jax.numpy as jnp


class SparseMatrixLinearOperator(lx.MatrixLinearOperator):
    # Override the matrix-vector product so that it works with a sparse BCOO matrix.
    def mv(self, vector):
        return js.sparsify(lambda m, v: m @ v)(self.matrix, vector)


# Build a (trivial) sparse matrix in BCOO format and wrap it in the operator.
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
x = js.BCOO.fromdense(x)
op = SparseMatrixLinearOperator(x)
vec = jnp.array([1., 2.])


@eqx.filter_jit
def f(op, vec):
    # GMRES only needs matrix-vector products, so the matrix is never densified.
    sol = lx.linear_solve(op, vec, solver=lx.GMRES(rtol=1e-5, atol=1e-5))
    return sol


print(f(op, vec).value)
```
This basically just involves (a) overriding the `mv` (matrix-vector product) computation to work with a sparse matrix, and (b) using an iterative solver (here GMRES) rather than something like LU, which requires materialising the matrix.
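Point (b) is what makes this work: an iterative solver such as GMRES only ever touches the matrix through products $Av$. As a minimal sketch of that idea outside Lineax (it uses JAX's own `jax.scipy.sparse.linalg.gmres` rather than Lineax's API, and the tridiagonal matrix and the `matvec` helper are illustrative assumptions), the solver can be handed a plain function wrapping a `BCOO` product. The dense copy of $A$ here exists only to check the answer.

```python
import jax.numpy as jnp
import jax.experimental.sparse as js
from jax.scipy.sparse.linalg import gmres

# Illustrative sparse (tridiagonal) system; the dense copy is only used for checking.
dense = jnp.array([[4.0, 1.0, 0.0],
                   [1.0, 4.0, 1.0],
                   [0.0, 1.0, 4.0]])
A = js.BCOO.fromdense(dense)
b = jnp.array([1.0, 2.0, 3.0])


def matvec(v):
    # The solver's only access to A: a sparse matrix-vector product.
    return A @ v


# gmres accepts a callable in place of a matrix and returns (solution, info).
x, _ = gmres(matvec, b, tol=1e-6, atol=1e-6)

print(x)
print(jnp.allclose(dense @ x, b, atol=1e-4))
```

The same reasoning is why LU-style direct solvers are not a good fit here: they need the individual entries of $A$, whereas GMRES is satisfied with the `mv` override.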
This isn't built into Lineax by default, as JAX's sparse support is still pretty experimental, and we want to see how that's going to pan out before we design an API around it.
Hello, @patrick-kidger.
I am a new user of lineax. I would like to ask: does lineax now support automatic differentiation (AD) through sparse matrices?
Only to whatever extent JAX itself does.
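To give a rough idea of what "whatever extent JAX itself does" covers, here is a minimal pure-JAX sketch (not Lineax; the matrix and the `sparse_loss`/`dense_loss` functions are illustrative assumptions). It takes `jax.grad` of a loss through a sparse conjugate-gradient solve with respect to both the right-hand side and the stored nonzero values of a `BCOO` matrix, and compares the result against a dense direct solve. This relies on `jax.scipy.sparse.linalg.cg` being built on `jax.lax.custom_linear_solve`, and on `jax.experimental.sparse` providing derivative rules for the BCOO matrix-vector product.

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as js
from jax.scipy.sparse.linalg import cg

# Fixed sparsity pattern, taken from a small symmetric positive definite matrix.
dense0 = jnp.array([[4.0, 1.0, 0.0],
                    [1.0, 4.0, 1.0],
                    [0.0, 1.0, 4.0]])
A0 = js.BCOO.fromdense(dense0)
indices = A0.indices  # shape (nse, 2): (row, col) of each stored entry
data0 = A0.data       # shape (nse,): the stored nonzero values


def sparse_loss(data, b):
    # Rebuild the sparse matrix from its values; the pattern stays fixed.
    A = js.BCOO((data, indices), shape=dense0.shape)
    # cg only sees A through this matvec; AD goes through custom_linear_solve.
    x, _ = cg(lambda v: A @ v, b, tol=1e-6)
    return jnp.sum(x ** 2)


def dense_loss(data, b):
    # Same loss with a dense matrix and a direct solve, for comparison.
    A = jnp.zeros(dense0.shape).at[indices[:, 0], indices[:, 1]].set(data)
    x = jnp.linalg.solve(A, b)
    return jnp.sum(x ** 2)


b = jnp.array([1.0, 2.0, 3.0])
g_sparse = jax.grad(sparse_loss, argnums=(0, 1))(data0, b)
g_dense = jax.grad(dense_loss, argnums=(0, 1))(data0, b)

print(jnp.allclose(g_sparse[0], g_dense[0], atol=1e-4))  # w.r.t. nonzero values
print(jnp.allclose(g_sparse[1], g_dense[1], atol=1e-4))  # w.r.t. right-hand side
```

Note that `cg` assumes $A$ is symmetric positive definite, which holds for this example matrix; whether a given sparse operation is differentiable is decided by `jax.experimental.sparse`, not by the solver.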
Thank you very much for your reply!