lineax
lineax copied to clipboard
Issue with `vmap` when using `lx.linear_solve` on `SparseMatrixOperator` with multi-column RHS
I'm encountering a `ValueError` when using `vmap` to map the `lx.linear_solve` operation across multiple columns of a right-hand side (RHS) matrix with a `SparseMatrixOperator`.
I expect the `_solve_beta` function, when passed a `SparseMatrixOperator` and a multi-column RHS, to solve the linear system for each column of the RHS, just as it does when a `MatrixLinearOperator` is used. While the linear solve works fine for individual columns of `y` with a `SparseMatrixOperator`, it fails when attempting to solve for multiple columns using `vmap`, producing the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ... (trimmed for brevity)
Here is a minimal working example that reproduces the problem:
from __future__ import annotations
import numpy as np
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import lineax as lx
from jax.experimental import sparse
from jaxtyping import Array, ArrayLike, Float
class SparseMatrixOperator(lx.AbstractLinearOperator):
    """Linear operator backed by a ``jax.experimental.sparse`` matrix.

    The wrapped matrix is stored as-is in ``matrix``; the methods below expose
    its product, dense form, transpose and input/output structures.
    """

    matrix: sparse.JAXSparse

    def __init__(self, matrix):
        self.matrix = matrix

    def mv(self, vector: ArrayLike):
        """Return the product of the wrapped matrix with ``vector``."""
        # sparsify lets jnp.matmul accept the sparse matrix operand.
        sparse_matmul = sparse.sparsify(jnp.matmul)
        return sparse_matmul(
            self.matrix,
            vector,
            precision=lax.Precision.HIGHEST,
        )

    def as_matrix(self) -> Float[Array, "n p"]:
        """Return the wrapped matrix converted to dense form."""
        dense = self.matrix.todense()
        return dense

    def transpose(self) -> "SparseMatrixOperator":
        """Return a new operator wrapping the transposed matrix."""
        transposed = self.matrix.T
        return SparseMatrixOperator(transposed)

    def in_structure(self) -> jax.ShapeDtypeStruct:
        """Describe the input vector: length is the matrix's second dimension."""
        _n_rows, n_cols = self.matrix.shape
        return jax.ShapeDtypeStruct((n_cols,), self.matrix.dtype)

    def out_structure(self) -> jax.ShapeDtypeStruct:
        """Describe the output vector: length is the matrix's first dimension."""
        n_rows, _n_cols = self.matrix.shape
        return jax.ShapeDtypeStruct((n_rows,), self.matrix.dtype)
@lx.is_negative_semidefinite.register(SparseMatrixOperator)
def _(operator):
    """Report that a SparseMatrixOperator is never flagged negative semidefinite."""
    return False
@lx.linearise.register(SparseMatrixOperator)
def _(operator):
    """Return the operator unchanged as its own linearisation."""
    return operator
# Batched variant of lx.linear_solve: the first and third arguments are shared
# across the batch (None), while the second argument is mapped over its axis 1.
_multi_linear_solve = eqx.filter_vmap(
    lx.linear_solve,
    in_axes=(None, 1, None),
)
def _solve_beta(A: lx.AbstractLinearOperator, rhs: ArrayLike) -> Array:
    """Solve the linear system for ``A`` against every column of ``rhs``.

    Args:
        A: Operator passed unbatched to the vmapped solve.
        rhs: Right-hand side; the vmapped solve maps over its axis 1.

    Returns:
        The ``value`` of the batched solution, transposed.
    """
    cg_solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
    solution = _multi_linear_solve(A, rhs, cg_solver)
    # NOTE(review): presumably the batch lands on axis 0 of ``value``, so the
    # transpose puts one solution per column - confirm against filter_vmap.
    return solution.value.T
# Sizes used to build X with shape (N, P) and y with shape (N, K).
N, P, K = 10, 3, 2

# Seed the generator so every run of this reproducer solves the same problem;
# unseeded draws made the printed values differ from run to run.
rng = np.random.default_rng(0)
X = jnp.asarray(rng.normal(size=(N, P)))
y = jnp.asarray(rng.normal(size=(N, K)))
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)

# X is not actually sparse; it is only converted to BCOO so that we have a
# SparseMatrixOperator instance to compare against the dense operator.
G_sp = SparseMatrixOperator(sparse.BCOO.fromdense(X))
G_op = lx.MatrixLinearOperator(X)

# Works in the following cases: single-column solves with the sparse
# operator, and the vmapped multi-column solve with the dense operator.
print(lx.linear_solve(G_sp, y[:, 0], solver).value)
print(lx.linear_solve(G_sp, y[:, 1], solver).value)
print(_solve_beta(G_op, y))

# Does not work in this case: the vmapped multi-column solve with the sparse
# operator raises the ValueError reported in this issue.
print(_solve_beta(G_sp, y))