
reverse-mode AD with Sparse Matrices through solve

Open bhorowitz opened this issue 9 months ago • 3 comments

Hello Everyone,

I've been very impressed with lineax, and it seems like it might address many issues I've been having in my field (astrophysical simulations). I was hoping to back-propagate through a Poisson solver (inspired by the conversation and pull request here: https://github.com/patrick-kidger/lineax/pull/120 ). I ran into some issues getting reverse-mode AD to work through the solver when it involves a sparse matrix multiplication, although jacfwd works well. However, jacfwd is very memory intensive for large input fields.

I have a simple example notebook here: https://github.com/bhorowitz/DiffAPM/blob/main/02-possion2d-deriv.ipynb

Perhaps there is a fundamental issue trying to do this sort of operation?

Thank you for your help! Ben Horowitz

bhorowitz avatar Mar 26 '25 12:03 bhorowitz

Your example shows an error that recommends using eqx.filter_grad over jax.grad, have you tried this?

johannahaffner avatar Mar 26 '25 15:03 johannahaffner

> Your example shows an error that recommends using eqx.filter_grad over jax.grad, have you tried this?

I think you might be misreading the error message: it recommends using eqx.filter_jit over jax.jit (if that is done, this fairly terrifying-looking error message becomes much more readable).

> I was running into some issues getting reverse mode AD to work through the solver where there is a sparse matrix multiplication, although jacfwd works well. However, jacfwd is very memory intensive for large input fields.

As for the error itself! grad-of-linsolve is another linsolve, and it is one for which we unconditionally set throw=True (as on the backward pass there is nowhere we can pipe an error code to). In this case it looks like the linear system that is set up is one which GMRES is unable to solve. (Possibly the forward system couldn't be solved either, but there you have throw=False, so the failure is reported only via sol = lx.linear_solve(...); sol.result.)

That's not too surprising, unfortunately: most of the matrix-free iterative solvers are designed to work only on certain problems (largely those arising in certain PDE settings) and often require preconditioning. That's not a Lineax thing; it's a property of the algorithms themselves.

patrick-kidger avatar Mar 26 '25 19:03 patrick-kidger

> I think you might be misreading the error message

Indeed. Sorry, my bad!

Re: problem structures: shouldn't this be one of the cases where this solver is actually a good match? (Laplacian operators, ...)

johannahaffner avatar Mar 26 '25 21:03 johannahaffner