lineax icon indicating copy to clipboard operation
lineax copied to clipboard

Issue with `vmap` when using `lx.linear_solve` on `SparseMatrixOperator` with multi-column RHS

Open Dong555 opened this issue 8 months ago • 1 comments

I'm encountering a `ValueError` when using `vmap` (via `eqx.filter_vmap`) to map `lx.linear_solve` across the columns of a right-hand-side (RHS) matrix with a custom `SparseMatrixOperator`.

I expect the `_solve_beta` function, when passed a `SparseMatrixOperator` and a multi-column RHS, to solve the linear system for each column of the RHS, just as it does when a `MatrixLinearOperator` is used. The linear solve works fine for individual columns of `y` with a `SparseMatrixOperator`, but solving for multiple columns at once using `vmap` fails with the following error: `ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ...` (trimmed for brevity)

Here is a minimal working example that reproduces the problem:

from __future__ import annotations

import numpy as np
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import lineax as lx
from jax.experimental import sparse
from jaxtyping import Array, ArrayLike, Float

class SparseMatrixOperator(lx.AbstractLinearOperator):
    """Linear operator backed by a ``jax.experimental.sparse`` matrix.

    Wraps one sparse matrix and exposes its matrix-vector product, dense
    form, transpose, and input/output structures.
    """

    # The wrapped sparse matrix; the structure methods unpack its shape as 2-D.
    matrix: sparse.JAXSparse

    def __init__(self, matrix):
        """Store ``matrix`` as the operator's underlying sparse matrix."""
        self.matrix = matrix

    def mv(self, vector: ArrayLike):
        """Return the product of the wrapped matrix with ``vector``."""
        # `sparsify` lifts `jnp.matmul` so it accepts the sparse operand.
        sparse_matmul = sparse.sparsify(jnp.matmul)
        return sparse_matmul(
            self.matrix,
            vector,
            precision=lax.Precision.HIGHEST,
        )

    def as_matrix(self) -> Float[Array, "n p"]:
        """Return the wrapped matrix converted to a dense array."""
        dense = self.matrix.todense()
        return dense

    def transpose(self) -> "SparseMatrixOperator":
        """Return a new operator wrapping the transposed matrix."""
        transposed = self.matrix.T
        return SparseMatrixOperator(transposed)

    def in_structure(self) -> jax.ShapeDtypeStruct:
        """Describe the input vector: length is the matrix's second dimension."""
        _n_rows, n_cols = self.matrix.shape
        return jax.ShapeDtypeStruct(shape=(n_cols,), dtype=self.matrix.dtype)

    def out_structure(self) -> jax.ShapeDtypeStruct:
        """Describe the output vector: length is the matrix's first dimension."""
        n_rows, _n_cols = self.matrix.shape
        return jax.ShapeDtypeStruct(shape=(n_rows,), dtype=self.matrix.dtype)

@lx.is_negative_semidefinite.register(SparseMatrixOperator)
def _(op):
    """Answer the ``is_negative_semidefinite`` query for ``SparseMatrixOperator``.

    Always returns ``False``; the wrapped matrix is never inspected.
    NOTE(review): presumably chosen as the conservative answer -- confirm.
    """
    return False

@lx.linearise.register(SparseMatrixOperator)
def _(op):
    """Handle ``lx.linearise`` for ``SparseMatrixOperator``.

    Returns the operator unchanged.
    """
    return op

# Vectorised `lx.linear_solve`. in_axes=(None, 1, None) maps over axis 1 of the
# second positional argument and leaves the first and third arguments unmapped.
_multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1, None))

def _solve_beta(A: lx.AbstractLinearOperator, rhs: ArrayLike) -> Array:
    """Solve the linear system defined by ``A`` for every column of ``rhs``.

    The solves are dispatched through the module-level ``_multi_linear_solve``
    with an ``lx.NormalCG`` solver (``rtol = atol = 1e-6``).

    Returns:
        The transpose of the batched solution's ``value``.
        NOTE(review): presumably one column per column of ``rhs`` -- confirm.
    """
    normal_cg = lx.NormalCG(rtol=1e-6, atol=1e-6)
    batched_solution = _multi_linear_solve(A, rhs, normal_cg)
    return batched_solution.value.T

# Sizes used below: X is (N, P) and y is (N, K), i.e. K right-hand-side columns.
N, P, K = 10, 3, 2
# Unseeded random data, so the printed values differ from run to run.
X = jnp.asarray(np.random.normal(size=(N, P)))
y = jnp.asarray(np.random.normal(size=(N, K)))
# Solver for the single-column calls below; `_solve_beta` builds its own.
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)

# X is fully populated random data, so the BCOO matrix is sparse in format
# only; all that is needed here is an operator instance.
G_sp = SparseMatrixOperator(sparse.BCOO.fromdense(X))
G_op = lx.MatrixLinearOperator(X)

# Works in the following cases: one column at a time with the sparse
# operator, and the vmapped multi-column solve with the dense operator.
print(lx.linear_solve(G_sp, y[:, 0], solver).value)
print(lx.linear_solve(G_sp, y[:, 1], solver).value)
print(_solve_beta(G_op, y))

# Does not work in this case: the vmapped multi-column solve with the sparse
# operator raises the ValueError reported in this issue.
print(_solve_beta(G_sp, y))

Dong555 avatar Oct 23 '23 21:10 Dong555