Query: How does Lineax write the JVP rule for linear solvers?
Hey guys,
Thanks for providing such a nice ecosystem (Equinox, Optimistix, Lineax, ...) for complex scientific computations. I am in the process of learning to write custom AD rules for linear solvers, specifically for solvers from external libraries. I know that the JVP rule is simple to write and is the more useful option at the moment (since JAX has not exposed control over transpose rules, see https://github.com/jax-ml/jax/issues/9129 or https://github.com/jax-ml/jax/pull/17840), and it gives both forward- and reverse-mode AD.
The JVP rule would look like:

```python
import jax
import jax.numpy as jnp

some_solver = jnp.linalg.solve  # stand-in; in practice an external routine


@jax.custom_jvp
def solve(A, b):
    # Solve A x = b.
    return some_solver(A, b)


@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    x = solve(A, b)
    # Differentiating A x = b gives A_dot x + A x_dot = b_dot,
    # so x_dot = A^{-1} (b_dot - A_dot x).
    x_dot = some_solver(A, b_dot - A_dot @ x)
    return x, x_dot
```
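As a sanity check (hypothetical, with `jnp.linalg.solve` standing in for `some_solver`), both AD modes go through:

```python
A = 2.0 * jnp.eye(3)
b = jnp.ones(3)
x, x_dot = jax.jvp(solve, (A, b), (jnp.zeros_like(A), jnp.ones_like(b)))
# Reverse mode also works here -- but only because jnp.linalg.solve is built
# from transposable JAX primitives. This is exactly what breaks once
# some_solver is an opaque external call.
g = jax.grad(lambda b: solve(A, b).sum())(b)
```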
The issue is that this fails when the solver is external, even if we wrap it with `jax.pure_callback` (the callback has no transpose rule, so reverse mode breaks).
I had a discussion about this in JAX (https://github.com/jax-ml/jax/discussions/25528), but the end result was that it would require a full-blown primitive.
But if my understanding is correct, Lineax does this somehow. Can you give some insight into how this was achieved?
P.S. The end goal for me is to also add sparsity into the mix.
Lineax's JVP rule is defined here. We actually do define a custom primitive, rather than using jax.custom_jvp, because we need to define a custom transposition rule. :)
Fortunately for your case, this has all been handled for you in a solver-agnostic way. You should be able to just implement a lineax.AbstractLinearSolver and then things will work for you from there.
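Roughly, the interface to implement looks like this (a skeleton only; method semantics follow the lineax documentation, and a full worked example appears later in this thread):

```python
import lineax as lx


class MyExternalSolver(lx.AbstractLinearSolver):
    def init(self, operator, options):
        ...  # precompute whatever the solver needs; returned as the solver state

    def compute(self, state, vector, options):
        ...  # return (solution, lx.RESULTS.successful, stats_dict)

    def transpose(self, state, options):
        ...  # state/options for solving A^T x = b; used by the transposition rule

    def conj(self, state, options):
        ...  # state/options for solving conj(A) x = b

    def allow_dependent_columns(self, operator):
        return False

    def allow_dependent_rows(self, operator):
        return False
```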
Thanks Patrick. I will attempt to do that in this thread (and create a MWE), and then close this so that it could be useful for others as well.
Hey @SNMS95, I realised that I somehow cross-posted on a very similar topic (https://github.com/patrick-kidger/lineax/issues/173). Did you manage to come up with an implementation?
Hey @vboussange ,
I did, it looks something like this. Since I had to deal with sparsity on top as well, I first made a SparseMatrixLinearOperator to wrap a jax.experimental.sparse.BCOO matrix. If you are working with mat-vecs directly, you can use one of the other operators from lineax (I think lx.FunctionLinearOperator?). The idea is then to use jax.pure_callback to perform the calls to the external package. Since AbstractLinearSolver is wrapped as a primitive, AD works as usual, and since we used pure_callback, JIT works well too!
```python
import jax
import jax.numpy as jnp
import numpy as np
import cvxopt
import cvxopt.cholmod
import lineax as lx
from jax.experimental.sparse import BCOO


class SparseMatrixLinearOperator(lx.AbstractLinearOperator, strict=True):
    bcoo_matrix: BCOO

    def __init__(self, bcoo_matrix):
        self.bcoo_matrix = bcoo_matrix

    def mv(self, x):
        return self.bcoo_matrix @ x

    def as_matrix(self):
        return self.bcoo_matrix.todense()

    def transpose(self):
        # The matrix is symmetric, so transposing is a no-op.
        # In general: SparseMatrixLinearOperator(self.bcoo_matrix.transpose())
        return self

    def in_structure(self):
        _, in_size = self.bcoo_matrix.shape
        return jax.ShapeDtypeStruct(shape=(in_size,), dtype=self.bcoo_matrix.dtype)

    def out_structure(self):
        out_size, _ = self.bcoo_matrix.shape
        return jax.ShapeDtypeStruct(shape=(out_size,), dtype=self.bcoo_matrix.dtype)


@lx.is_symmetric.register(SparseMatrixLinearOperator)
def _(operator):
    return True


@lx.linearise.register(SparseMatrixLinearOperator)
def _(operator):
    return operator
```
```python
class CVXOPTSolver(lx.AbstractLinearSolver):
    symmetrise_before_solve: bool = False

    def __init__(self, symmetrise_before_solve=False):
        self.symmetrise_before_solve = symmetrise_before_solve

    def init(self, operator, options):
        # The operator itself serves as the solver state.
        return operator

    def compute(self, solver_state, rhs, options):
        bcoo_matrix = solver_state.bcoo_matrix
        data = bcoo_matrix.data
        indices = bcoo_matrix.indices

        def host_solve(data, indices, rhs):
            data = np.array(data).astype(rhs.dtype)
            indices = np.array(indices.T)
            rhs = np.array(rhs)
            K = cvxopt.spmatrix(data, indices[0, :], indices[1, :])
            if self.symmetrise_before_solve:
                K = (K + K.T) / 2.0
            B = cvxopt.matrix(rhs)
            cvxopt.cholmod.linsolve(K, B)  # solves in place into B
            return np.array(B).astype(rhs.dtype).reshape(rhs.shape)

        # Call the external solver on the host.
        result_shape_dtypes = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
        sol = jax.pure_callback(
            host_solve,
            result_shape_dtypes,
            data,
            indices,
            rhs,
            vmap_method="sequential",
        )
        return sol, lx.RESULTS.successful, {}

    def allow_dependent_columns(self, operator):
        return False

    def allow_dependent_rows(self, operator):
        return False

    def conj(self, solver_state, options):
        # Real-valued and symmetric, so this reduces to the transpose.
        return solver_state.transpose(), options

    def transpose(self, state, options):
        return state.transpose(), options
```
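A usage sketch (assuming a symmetric BCOO matrix `A` and a right-hand side `b` have already been built):

```python
# Solve, JIT, and differentiate through the external solver.
operator = SparseMatrixLinearOperator(A)
solution = lx.linear_solve(operator, b, solver=CVXOPTSolver())
x = solution.value

# lineax wraps the solve in its own primitive with a custom transposition
# rule, so jax.jit and reverse-mode AD work even though the actual solve
# happens on the host via pure_callback.
def loss(b):
    return lx.linear_solve(operator, b, solver=CVXOPTSolver()).value.sum()

g = jax.grad(loss)(b)
```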
Hey @SNMS95, thanks for the update!
I was looking into a solution where we could compute the internal state of the solver during the init phase and pass it on to the solver_state, so as to avoid overheads when making multiple solves.
In your case, multiple solves involve little overhead, so your approach makes total sense. You could optimise it slightly when symmetrise_before_solve == True by converting K (your external solver's internal state) to JAX-compatible arrays (data, indices), which could in turn be stored as JAX-compatible state.
In my case, the external solver's init phase is a significant overhead, so it makes more sense to initialise the external solver only once (I am interested in multigrid solvers, where the internal state corresponds to a multigrid hierarchy stored in a pyamg.MultilevelSolver object). This internal state cannot easily be converted to a JAX-compatible struct, and I am unsure how to deal with this.
Hi @vboussange ,
If you use filter_pure_callback from equinox, you should be able to store Python objects in the initial state as well.
If you crack this, it would be nice if you could post it here so it could be useful for others too.
I believe that internally lineax uses the "filtered" versions everywhere, so Python objects should be okay! If you want to say hi, my email is "[email protected]"
I could indeed make eqx.filter_pure_callback work with Python objects as inputs (see https://github.com/patrick-kidger/lineax/issues/173), but the problem is that I could not use it to output arbitrary Python objects. Hence I am stuck initialising the solver outside a JIT region, whereas the state needs to be modifiable within a JIT region to benefit from the differentiation rules, which call the transpose and conj methods.
I am sure @patrick-kidger would have a good hint on whether this would be possible.
So the only things that can be passed around inside of JAX, at runtime, are JAX arrays. After all, the whole thing gets lowered to an XLA computation graph!
So you need to find a way to express your Python object within that restriction.
At least two approaches come to mind:
- If you are able to serialise your object into a fixed-length buffer of bytes, then you can serialise/deserialise it on either side of each callback.
- If your object is completely opaque, then (a) in your first `pure_callback` you could record your object in a global `dict[int, YourObject]`, (b) inside of JAX pass around the integer key as a scalar JAX array, (c) inside later `pure_callback`s you can retrieve your object using the key, and (d) make sure to clear out the object from the dictionary at some point after the program has finished running (or inside your second `pure_callback`). A sketch of this pattern follows below.
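A minimal sketch of that opaque-handle pattern (all names here, e.g. ExpensiveExternalSolver and the registry helpers, are hypothetical):

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

_REGISTRY: dict[int, object] = {}
_COUNTER = itertools.count()


def _init_on_host(A):
    # (a) Build the expensive host-side object and stash it under an int key.
    solver = ExpensiveExternalSolver(np.asarray(A))  # hypothetical external object
    key = next(_COUNTER)
    _REGISTRY[key] = solver
    return np.int32(key)


def _solve_on_host(key, b):
    # (c) Retrieve the object by its key and use it.
    solver = _REGISTRY[int(key)]
    return np.asarray(solver.solve(np.asarray(b)))  # hypothetical .solve method


def init_solver(A):
    # (b) The key travels through JAX as a scalar int32 array.
    return jax.pure_callback(_init_on_host, jax.ShapeDtypeStruct((), jnp.int32), A)


def solve_with_key(key, b):
    # (d) Remember to _REGISTRY.pop(int(key)) in a final callback once done.
    return jax.pure_callback(_solve_on_host, jax.ShapeDtypeStruct(b.shape, b.dtype), key, b)
```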
Right, thanks for the hints @patrick-kidger !
Could there be a third option involving vmap and batching, where host_solve (called with pure_callback within the solver's compute method) initialises the external solver once, then explicitly loops through a batch of inputs and returns a batched vector of solutions? This would be a particularly attractive option for external solvers which have specialised algorithms for batched linear problems.
However, I am really not sure how to interface this batched approach with lineax. Could this option be made viable with some tricks to comply with the lineax interface?
I think handling batching is a separate concern. It looks like you need to both (a) have a way for multiple pure_callbacks to communicate with each other (as in my previous message) and (b) have a way to handle a batch of solves.
FWIW, I think batching should be pretty straightforward. Lineax does nothing special here -- it simply calls jax.vmap on the implementation provided by your solver. Assuming that solver is itself using jax.pure_callback, you'll simply end up needing to set vmap_method appropriately.
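For example (a rough sketch, reusing `operator` and `CVXOPTSolver` from earlier, with `batched_rhs` of shape `(batch, n)`):

```python
# vmap over a batch of right-hand sides; lineax just vmaps the solver's
# implementation, so the pure_callback's vmap_method decides what happens.
solve_one = lambda rhs: lx.linear_solve(operator, rhs, solver=CVXOPTSolver()).value
batched_x = jax.vmap(solve_one)(batched_rhs)

# With vmap_method="sequential" (as in compute above), the host callback runs
# once per batch element. A host solver with a specialised batched algorithm
# could instead receive the whole batch at once by picking a different
# vmap_method and handling the batch axis inside the callback.
```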