
Query: How does Lineax write the JVP rule for linear solvers?

Open SNMS95 opened this issue 10 months ago • 10 comments

Hey guys,

Thanks for providing a nice ecosystem (equinox, optimistix, lineax...) for complex scientific computations. I am learning to write custom AD rules for linear solvers, specifically for solvers from external libraries. I know that a JVP rule is simpler to write and currently more useful, since it gives both forward- and reverse-mode AD (JAX has not exposed control over transpose rules: https://github.com/jax-ml/jax/issues/9129 and https://github.com/jax-ml/jax/pull/17840).

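The JVP rule would look like this:

import jax
import jax.numpy as jnp

# Placeholder for the external solver. It is jnp.linalg.solve here so the
# snippet runs; the problem described below appears once this becomes a
# pure_callback into an external library.
some_solver = jnp.linalg.solve

@jax.custom_jvp
def solve(A, b):
    # Solve A x = b
    return some_solver(A, b)

@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    x = solve(A, b)
    # Differentiating A x = b: A_dot x + A x_dot = b_dot
    # so x_dot = A_inv (b_dot - A_dot x)
    x_dot = some_solver(A, b_dot - A_dot @ x)
    return x, x_dot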

The issue is that this fails when the solver is external, even if we wrap it in pure_callback: reverse mode needs to transpose the tangent computation, and JAX cannot transpose a pure_callback. I had a discussion about it in the JAX repo (https://github.com/jax-ml/jax/discussions/25528), but the conclusion was that it would require a full-blown primitive.
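For reference, here is the math behind the JVP rule above together with its reverse-mode counterpart. Reverse mode needs a solve with the transposed matrix, which is exactly the step JAX would have to derive by transposing the tangent computation; when the solve is an opaque callback, that transposition has to be supplied by hand somewhere. The symbols follow the code: x solves A x = b, dots are tangents, bars are cotangents, and lambda is an auxiliary vector introduced here for illustration.

```latex
% Forward mode: differentiate A x = b.
\dot{A}\,x + A\,\dot{x} = \dot{b}
\quad\Longrightarrow\quad
\dot{x} = A^{-1}\left(\dot{b} - \dot{A}\,x\right)

% Reverse mode: transpose the linear map (\dot{A}, \dot{b}) \mapsto \dot{x}.
% Given a cotangent \bar{x}, solve once with the transposed matrix:
A^{\top}\lambda = \bar{x},
\qquad
\bar{b} = \lambda,
\qquad
\bar{A} = -\lambda\,x^{\top}
```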

But if my understanding is correct, Lineax does this somehow. Could you give some insight into how this is achieved?

P.S. My end goal is to also add sparsity into the mix.

SNMS95 avatar Feb 14 '25 10:02 SNMS95

Lineax's JVP rule is defined here. We actually do define a custom primitive, rather than using jax.custom_jvp, because we need to define a custom transposition rule. :)

Fortunately for your case, this has all been handled for you in a solver-agnostic way. You should be able to just implement a lineax.AbstractLinearSolver and then things will work for you from there.

patrick-kidger avatar Feb 15 '25 08:02 patrick-kidger

Thanks Patrick. I will attempt that in this thread (with an MWE) and then close the issue, so that it is useful for others as well.

SNMS95 avatar Feb 17 '25 08:02 SNMS95

Hey @SNMS95, I realised that I somehow cross-posted on a very similar topic (https://github.com/patrick-kidger/lineax/issues/173). Did you manage to come up with an implementation?

vboussange avatar Oct 01 '25 05:10 vboussange

Hey @vboussange ,

I did; it looks something like this. Since I also had to deal with sparsity, I first made a sparse operator (SparseMatrixLinearOperator) that wraps jax.experimental.sparse.BCOO. If you are working with mat-vecs directly, you can use one of lineax's other operators instead (I think FunctionLinearOperator?). The idea is then to use pure_callback to call the external package. Since lineax wraps the call to the AbstractLinearSolver in a primitive, AD works as usual, and since we use pure_callback, JIT also works well!

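import jax
import jax.numpy as jnp
import numpy as np
import cvxopt
import cvxopt.cholmod
import lineax as lx
from jax.experimental.sparse import BCOO


class SparseMatrixLinearOperator(lx.AbstractLinearOperator, strict=True):
    # Wraps a 2D BCOO sparse matrix as a lineax linear operator.
    bcoo_matrix: BCOO

    def __init__(self, bcoo_matrix):
        self.bcoo_matrix = bcoo_matrix

    def mv(self, x):
        # Sparse matrix-vector product, computed inside JAX.
        return self.bcoo_matrix @ x

    def as_matrix(self):
        return self.bcoo_matrix.todense()

    def transpose(self):
        # Valid only because the matrix is assumed symmetric (see is_symmetric
        # below); otherwise return SparseMatrixLinearOperator(self.bcoo_matrix.T).
        return self

    def in_structure(self):
        _, in_size = self.bcoo_matrix.shape
        return jax.ShapeDtypeStruct(shape=(in_size,),
                                    dtype=self.bcoo_matrix.dtype)

    def out_structure(self):
        out_size, _ = self.bcoo_matrix.shape
        return jax.ShapeDtypeStruct(shape=(out_size,),
                                    dtype=self.bcoo_matrix.dtype)


# Tell lineax that this operator is symmetric.
@lx.is_symmetric.register(SparseMatrixLinearOperator)
def _(operator):
    return True


# The operator is already linear, so linearising it is a no-op.
@lx.linearise.register(SparseMatrixLinearOperator)
def _(operator):
    return operator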


class CVXOPTSolver(lx.AbstractLinearSolver):
    symmetrsize_before_solve: bool = False

    def __init__(self, symmetrsize_before_solve=False):
        self.symmetrsize_before_solve = symmetrsize_before_solve

    def init(self, operator, options):
        # Nothing is precomputed: the solver state is just the operator.
        return operator

    def compute(self, solver_state, rhs, options):
        bcoo_matrix = solver_state.bcoo_matrix
        data = bcoo_matrix.data
        indices = bcoo_matrix.indices  # shape (nse, 2): one (row, col) pair per entry

        def host_solve(data, indices, rhs):
            # Runs on the host with NumPy arrays, outside of JAX tracing.
            data = np.array(data).astype(rhs.dtype)
            indices = np.array(indices.T)  # shape (2, nse): rows, then columns
            rhs = np.array(rhs)
            K = cvxopt.spmatrix(data, indices[0, :], indices[1, :])
            if self.symmetrsize_before_solve:
                K = (K + K.T) / 2.0
            B = cvxopt.matrix(rhs)
            # Overwrites B in place with the solution of K X = B.
            cvxopt.cholmod.linsolve(K, B)
            return np.array(B).astype(rhs.dtype).reshape(rhs.shape)

        # Call the solver; its output has the same shape and dtype as rhs.
        result_shape_dtypes = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
        sol = jax.pure_callback(
            host_solve,
            result_shape_dtypes,
            data,
            indices,
            rhs,
            vmap_method="sequential")  # under vmap: one host call per batch element
        return sol, lx.RESULTS.successful, {}

    def allow_dependent_columns(self, operator):
        # The matrix is assumed square and nonsingular.
        return False

    def allow_dependent_rows(self, operator):
        return False

    def conj(self, solver_state, options):
        # The matrix is real, so conjugation leaves the state unchanged.
        return solver_state, options

    def transpose(self, state, options):
        return state.transpose(), options

SNMS95 avatar Oct 01 '25 08:10 SNMS95

Hey @SNMS95, thanks for the update! I was looking into a solution where the solver's internal state is computed once during the init phase and passed on via the solver_state, so as to avoid the overhead of redoing that work on every solve.

In your case, repeated solves involve little overhead, so your approach makes total sense. You could optimise it slightly when symmetrsize_before_solve==True by building the symmetrised K (your external solver's internal state) as JAX-compatible arrays (data, indices) during init, so that it can be stored in the solver state.
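A minimal sketch of that optimisation, assuming the symmetrisation moves into init: the (i, j, v) triplets of the BCOO matrix are concatenated with their transposed copies (j, i, v) and the values halved. BCOO allows duplicate indices and sums them whenever the matrix is used, so the result represents (M + M^T)/2 without leaving JAX. `symmetrise_bcoo` is a hypothetical helper, and the host library would also need to sum duplicate entries (or the matrix would need deduplicating first).

```python
import jax.numpy as jnp
from jax.experimental import sparse


def symmetrise_bcoo(M: sparse.BCOO) -> sparse.BCOO:
    # Append the transposed triplets (j, i, v) to the original (i, j, v) and
    # halve all values; duplicate (i, j) pairs are summed when M is used,
    # so the returned matrix equals (M + M^T) / 2.
    data = jnp.concatenate([M.data, M.data]) / 2
    indices = jnp.concatenate([M.indices, M.indices[:, ::-1]])
    return sparse.BCOO((data, indices), shape=M.shape)
```

In CVXOPTSolver, init could then return SparseMatrixLinearOperator(symmetrise_bcoo(operator.bcoo_matrix)), so the work happens once per init rather than once per compute.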

In my case, the external solver's init phase is a significant overhead, so it makes more sense to initialise the external solver only once (I am interested in multigrid solvers, and the internal state corresponds to a multigrid hierarchy stored in a pyamg.MultilevelSolver object). This internal state (the pyamg.MultilevelSolver object) cannot easily be converted to a JAX-compatible structure, and I am unsure how to deal with this.

vboussange avatar Oct 01 '25 08:10 vboussange

Hi @vboussange ,

If you use filter_pure_callback from equinox, you should be able to store Python objects in the initial state as well. If you crack this, it would be nice if you could post it here.

I believe lineax uses the "filtered" versions internally everywhere, so Python objects should be okay! If you want to say hi, my email is "[email protected]"

SNMS95 avatar Oct 01 '25 09:10 SNMS95

I could indeed make eqx.filter_pure_callback work with Python objects as inputs (see https://github.com/patrick-kidger/lineax/issues/173), but I could not use it to output arbitrary Python objects. Hence I am stuck initialising the solver outside a JIT region, whereas benefitting from the differentiation rules (which call the transpose and conj methods) requires being able to modify the state inside a JIT region.

I am sure @patrick-kidger would have a good hint on whether this would be possible.

vboussange avatar Oct 01 '25 09:10 vboussange

So the only things that can be passed around inside of JAX at runtime are JAX arrays. After all, the whole thing gets lowered to an XLA computation graph!

So you need to find a way to express your Python object within that restriction.

At least two approaches come to mind:

  • If you are able to serialise your object into a fixed-length buffer of bytes, then you can serialise/deserialise.
  • If your object is completely opaque, then (a) in your first pure_callback you could record your object in a global dict[int, YourObject], (b) inside of JAX pass around the integer key as a scalar JAX array, (c) inside later pure_callbacks you can retrieve your object using the key, (d) make sure to clear out the object from the dictionary at some point after the program has finished running (or inside your second pure_callback).

patrick-kidger avatar Oct 01 '25 10:10 patrick-kidger

Right, thanks for the hints @patrick-kidger !

Could there be a third option involving vmap and batching, where the host_solve called with pure_callback within the solver.compute method initialises the external solver once, then explicitly loops through a batch of inputs, and returns a batched vector of solutions? This would be a particularly attractive option for external solvers which have specialised algorithms for batched linear problems.

However, I am really not sure how to interface this batched approach with lineax. Could this option be viable with some tricks to comply with the lineax interface?

vboussange avatar Oct 01 '25 15:10 vboussange

I think handling batching is a separate concern. It looks like you need to both (a) have a way for multiple pure_callbacks to communicate with each other (as in my previous message) and (b) have a way to handle batch-of-solves.

FWIW, I think batching should be pretty straightforward. Lineax does nothing special here: it simply calls jax.vmap on the implementation provided by your solver. Assuming that solver itself uses jax.pure_callback, you'll simply need to set vmap_method appropriately.

patrick-kidger avatar Oct 03 '25 12:10 patrick-kidger