Differentiable matrix-free linear algebra, optimization and equation solving
This is a meta-issue for keeping track of progress on implementing differentiable higher-order functions from SciPy in JAX, e.g.,
- [ ] `scipy.sparse.linalg.gmres` and `cg`: matrix-free linear solves
- [ ] `scipy.sparse.linalg.eigs` and `eigsh`: matrix-free eigenvalue problems
- [ ] `scipy.optimize.root`: nonlinear equation solving
- [ ] `scipy.optimize.fixed_point`: solving for fixed points
- [ ] `scipy.integrate.odeint`: solving ordinary differential equations
- [ ] `scipy.optimize.minimize`: nonlinear minimization
These higher-order functions are important for implementing sophisticated differentiable programs, both for scientific applications and for machine learning.
Implementations should leverage and build upon JAX's custom transformation capabilities. For example, `scipy.optimize.root` should use autodiff to calculate the Jacobians or Jacobian-vector products needed for Newton's method.
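For instance, a dense Newton iteration driven entirely by autodiff might look roughly like this (an illustrative sketch, not a proposed API):

```python
import jax
import jax.numpy as jnp

def newton_root(f, x0, num_steps=20):
  """Illustrative dense Newton iteration: x <- x - J(x)^{-1} f(x)."""
  def body(x, _):
    # The Jacobian comes from autodiff rather than being hand-coded.
    step = jnp.linalg.solve(jax.jacobian(f)(x), f(x))
    return x - step, None
  x, _ = jax.lax.scan(body, x0, None, length=num_steps)
  return x

# Example: find the roots of f(x) = x**2 - [2, 3].
f = lambda x: x ** 2 - jnp.array([2.0, 3.0])
print(newton_root(f, jnp.ones(2)))  # approximately [1.414, 1.732]
```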
In most cases, I think the right way to do this involves two separate steps, which could happen in parallel:
- Higher-order primitives for defining automatic differentiation rules, but not specialized to any particular algorithm, e.g., `lax.custom_linear_solve` from https://github.com/google/jax/pull/1402 (see the sketch after this list).
- Implementations of particular algorithms for the forward problems, e.g., a conjugate gradient method for linear solves. These could either be implemented from scratch using JAX's functional control flow (e.g., `while_loop`) or could leverage existing external implementations on particular backends. Either way they will almost certainly need custom derivative rules, rather than differentiation through the forward algorithm.
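To make the first point concrete, here is roughly how `lax.custom_linear_solve` ties a forward solver to its derivative rules; a rough sketch with a deliberately trivial forward solver (not code from #1402):

```python
import jax
import jax.numpy as jnp
from jax import lax

# A small symmetric positive-definite operator, used matrix-free via a closure.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
matvec = lambda x: A @ x

def solve(matvec, b):
  # Any forward algorithm could go here (CG, GMRES, an external solver, ...);
  # for brevity we just materialize the matrix.
  return jnp.linalg.solve(A, b)

def solve_with_grads(b):
  # Derivatives are defined by additional linear solves (here reusing `solve`
  # for the transpose, since the operator is symmetric), not by differentiating
  # through whatever iterations the forward solver happens to run.
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

b = jnp.array([1.0, 2.0])
print(solve_with_grads(b))
print(jax.jacobian(solve_with_grads)(b))  # equals inv(A), via the solve above
```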
There's lots of work to be done here, so please comment if you're interested in using or implementing any of these.
I'll start things off.
I've written initial versions of higher order primitives defining gradients for non-linear root finding (see discussion in https://github.com/google/jax/issues/1448) and linear equation solving (https://github.com/google/jax/pull/1402).
I've also written experimental implementations in JAX for `gmres` and a thick-restart Lanczos method (for `eigsh`). These are not shared publicly yet, but I could drop them into `jax.experimental`. (EDIT: see https://github.com/google/jax/pull/3114 for thick-restart Lanczos.)
@shoyer With this addition and, in general, implicit diff related features, are there any plans for a mechanism to extract intermediate/aux values when differentiating to allow us to log things like number of iterations for tangent solve, residual, etc.? (related: #1197 #844)
Yes, I'm working on a general feature for referring to (and extracting, injecting, differentiating w.r.t.) intermediate/aux values; it's coming soon 🙂
I'm glad to hear @jekbradbury is thinking about this, because I have not! I agree it's important.
In this case, auxiliary outputs should not have derivatives defined. If we were to solve this entirely inside `custom_linear_solve` etc., then I would suggest simply adding an explicit `has_aux` argument that changes the function signature to return an extra output. But then it isn't obvious how we could pipe out auxiliary outputs from the forward or transpose passes.
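(For reference, this would mirror the `has_aux` convention that `jax.value_and_grad` already uses, where the auxiliary output is passed through undifferentiated; a minimal illustration:)

```python
import jax
import jax.numpy as jnp

def loss_with_diagnostics(x):
  loss = jnp.sum(x ** 2)
  aux = {"residual_norm": jnp.sqrt(loss)}  # diagnostics we'd like to pipe out
  return loss, aux

# The auxiliary output is returned as-is and is not differentiated.
(value, aux), grad = jax.value_and_grad(loss_with_diagnostics, has_aux=True)(
    jnp.arange(3.0))
```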
@jekbradbury Awesome, looking forward to it!
(@ all JAX maintainers/contributors) I'm loving the style and direction of JAX; keep up the great work!
#2566 adds `cg`, so at least that's a start.
@shoyer, what's the status of `gmres`? I started implementing it today, but then I realized you have something already.
I have a very naive implementation of GMRES with preconditioning that you can find here: https://gist.github.com/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80
It needs more careful testing (and possibly improvements to the numerics) before merging into JAX.
To be clear, I have no immediate plans to continue work on my `gmres` solver. If you want to take this on, that would be fantastic!
@shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.
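For example, a COO matvec can be wrapped as a matrix-free operator and handed straight to `cg`; a rough sketch (made-up variable names, not the actual #3717 kernel):

```python
import jax.numpy as jnp
from jax import ops
from jax.scipy.sparse.linalg import cg

def coo_matvec_operator(data, rows, cols, n):
  """Matrix-free x -> A @ x for an n x n sparse matrix A in COO format."""
  def matvec(x):
    return ops.segment_sum(data * x[cols], rows, num_segments=n)
  return matvec

# Usage, assuming data/rows/cols describe a symmetric positive-definite A:
# x, info = cg(coo_matvec_operator(data, rows, cols, n), b)
```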
@romanodev this is really awesome work.
> @shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.
This other PR has a vectorized version of Gram-Schmidt, which I think could replace `_inner` in my implementation of GMRES above: https://github.com/google/jax/pull/3114 (it is basically the same algorithm).
OK, here's a version with vectorized Gram-Schmidt, which is perhaps 10-20% faster on CPU and ~10x faster on GPU: https://gist.github.com/shoyer/dc33a5850337b6a87d48ed97b4727d29
The GPU is still spending most of its time waiting when solving for a 100-dimensional vector, but for large enough solves the performance should be reasonable.
I suspect the main remaining improvement (which we should probably have before merging) would be adding some form of early termination based on residual values.
@shoyer, great. Any reason for not using experimental.loops?
Just for the sake of prototyping, I rewrote your version using loops (just personal taste, for faster iteration). Using `jnp.linalg.lstsq`, I plotted the residual at each iteration. Within loops we can easily handle early termination, but I am not sure how to do it with `lax.scan`.
In any case, the preconditioning does not seem to work yet. I considered a simple case where we invert the diagonal of A, and it doesn't match SciPy. Although this choice is not justified since A does not have a dominant diagonal, it should at least serve for testing.
Here is the gist:
https://gist.github.com/romanodev/e3f6bd23c499cd8a5f26b26c140abcac
> Any reason for not using experimental.loops?
It's just a matter of taste. Personally I don't find it much clearer than using functions if there's only one level of loops.
It's also a little easier to avoid inadvertently using Python control flow with the control-flow functions, e.g., in your example there's one place where you should be using `s.range()` instead of `range()` (at least in the "real" version; maybe you avoided that intentionally for printing?).
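If I'm remembering the `experimental.loops` API correctly, the distinction looks roughly like this (illustrative sketch, not code from your gist):

```python
from jax.experimental import loops

def sum_first_n(x, n):
  with loops.Scope() as s:
    s.total = 0.0
    for i in s.range(n):   # traced: becomes a single lax loop
      s.total += x[i]
    # `for i in range(n):` would instead unroll at trace time, i.e. silently
    # fall back to Python control flow.
    return s.total
```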
For printing intermediate outputs from inside compiled code like "for" loops, take a look at `jax.experimental.host_callback`: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html
For early termination logic in general, take a look at the source code for `cg`. The loop needs to be written in terms of `while_loop` instead of `scan`:
https://github.com/google/jax/blob/0d81e988d8884a95ced928b026fec003bb34af57/jax/scipy/sparse/linalg.py#L54
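The basic pattern is a residual-based condition on the `while_loop` carry, something like this stripped-down sketch (not the actual `cg` code; the update step is a placeholder):

```python
import jax.numpy as jnp
from jax import lax

def iterative_solve(matvec, b, x0, tol=1e-5, maxiter=1000):
  """Skeleton of an iterative solver with residual-based early termination."""
  def cond(state):
    x, r, k = state
    return (jnp.linalg.norm(r) > tol * jnp.linalg.norm(b)) & (k < maxiter)

  def body(state):
    x, r, k = state
    # One step of the actual algorithm (CG, GMRES, BiCGSTAB, ...) goes here;
    # plain Richardson iteration is used as a stand-in.
    x = x + r
    r = b - matvec(x)
    return x, r, k + 1

  r0 = b - matvec(x0)
  x, r, k = lax.while_loop(cond, body, (x0, r0, 0))
  return x
```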
Hi, in parallel with this, I'd like to add an implementation for bicgstab (which I believe is also matrix-free)! I've started working on something similar in cupy (https://github.com/cupy/cupy/pull/3569), so I thought I might as well add it to jax.
Sure, we would love to see an implementation of bicgstab. The implementation should be relatively straightforward (easier than gmres). Please copy the style of the existing cg implementation, including tests.
Excellent! If GMRES works out too, then we can even do a more general form of BiCGSTAB, which may be useful for EM problems (as suggested by meep). Not sure if this is implemented in scipy, so unsure where something like that would go?
Meanwhile, as a reference, this is what I have so far on GMRES (using `while_loop`; still needs to be tested carefully):
https://gist.github.com/romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f
With regards to the "more general version of GMRES", I'm sure this could be useful but my preference would be to stick to established algorithms in JAX. Inclusion in SciPy is one good indication of that. I would suggest releasing more experimental algorithms in separate packages.
Hi @shoyer, I have a general question as I prepare this PR. What is the motivation for defining pytree functions like `_add` and `_sub` in the `_cg_solve` method? Is this for jit compilation? It appears I will need to add a more general `_vdot_tree` function for bicgstab to work for more general matrices, which is why I'm asking!
The reason for doing this is that it lets us invert linear equations on arbitrarily shaped arrays in arbitrary structures. In some cases, this is a much more natural way to implement linear operators, e.g., you might use a handful of 3D arrays to represent a solution to a PDE on a grid.
By supporting pytrees, we don't need to copy these arrays into a single flattened vector. This can actually add significant overhead due to extra copying, e.g., it's 60-90% slower to solve a Poisson equation in 2D using CG with a flattened array than with a 2D array: https://gist.github.com/shoyer/6826d02949e4d2ce82122a8bd5c62cf7
That said, these algorithms are much easier to implement/verify on single vectors, and in the long term I'd like to solve vectorization with a `tree_vectorize` code transformation instead -- see https://github.com/google/jax/pull/3263. This will hopefully be merged into JAX within the next month or so.
So if you prefer, I would also be OK with implementing these algorithms on single vectors and adding explicit flattening/unflattening to handle pytrees. You can find an example of doing this sort of thing inside `_odeint_wrapper` -- you could basically use the exact same thing for something like `cg` by omitting the `jax.vmap` at the end:
https://github.com/google/jax/blob/fa2a0275c83c03cfad9d36f5da06b0bf47eedfb9/jax/experimental/ode.py#L210-L214
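A rough sketch of that wrapper, using `jax.flatten_util.ravel_pytree` the same way `_odeint_wrapper` does (the function names here are just illustrative):

```python
from jax.flatten_util import ravel_pytree

def wrap_flat_solver(flat_solve):
  """Adapt a solver written for 1D vectors so it accepts pytree operands."""
  def solve(matvec, b, **kwargs):
    b_flat, unravel = ravel_pytree(b)

    def flat_matvec(v_flat):
      out = matvec(unravel(v_flat))
      out_flat, _ = ravel_pytree(out)
      return out_flat

    x_flat = flat_solve(flat_matvec, b_flat, **kwargs)
    return unravel(x_flat)
  return solve
```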
@shoyer I added a PR for `bicgstab`, but it's still WIP; I was just looking for some comments before I finalize the implementation and tests, but I think it's close! I think changing to the flattened version would be easy to extend to both `cg` and `bicgstab`, so we could handle that in a separate PR?
I would just like to add another comment on @shoyer's point about supporting operators on pytree arrays. I have a use case where I need to compute the vector norm / dot product for arrays structured as pytrees; it would be very useful to support these operations.
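For what it's worth, something in the spirit of `_vdot_tree` can already be written with `jax.tree_util`; a minimal sketch (my own helper names, not an existing JAX API):

```python
import jax.numpy as jnp
from jax import tree_util

def tree_vdot(x, y):
  """Sum of jnp.vdot over matching leaves of two pytrees."""
  leaf_dots = tree_util.tree_map(jnp.vdot, x, y)
  return tree_util.tree_reduce(jnp.add, leaf_dots)

def tree_norm(x):
  return jnp.sqrt(tree_vdot(x, x))

# Example with a dict-of-arrays pytree:
x = {"a": jnp.ones((2, 3)), "b": jnp.arange(4.0)}
print(tree_norm(x))
```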
> I think changing to the flattened version would be easy to extend to both `cg` and `bicgstab`, so we could handle that in a separate PR?
To be clear, this was a suggestion for making it easier to write new solvers. Actually, it would even be fine not to support pytrees at all on new solvers. `cg` already supports pytrees without flattening, and we should keep this functionality!
Hi @shoyer, re: #3796 I think we should think about either a more robust testing pipeline for all matrix-free methods or just copy `scipy`'s tests and use that as our standard. What do you think are the appropriate next steps?
We should absolutely feel free to copy SciPy's tests. https://github.com/google/jax/pull/3101 (which I will be merging shortly) has a good example of how to do this.
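For example, a consistency check against SciPy might boil down to something like this sketch, where `jax_gmres` stands in for whatever solver is under test and is assumed to have a SciPy-like `(x, info)` return:

```python
import numpy as np
import scipy.sparse.linalg
import jax.numpy as jnp

def check_against_scipy(jax_gmres, n=50, seed=0):
  rng = np.random.RandomState(seed)
  A = rng.randn(n, n) + n * np.eye(n)   # well-conditioned test matrix
  b = rng.randn(n)
  expected, info = scipy.sparse.linalg.gmres(A, b, atol=1e-10)
  assert info == 0
  actual, _ = jax_gmres(jnp.asarray(A), jnp.asarray(b))
  np.testing.assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)
```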
Functioning and apparently efficient JAX implementations of `eigs`, `eigsh` using implicitly restarted Lanczos, and `gmres` are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into JAX. I've started a conversation in https://github.com/google/TensorNetwork/issues/785 with the other TensorNetwork devs on the matter.
We are also planning to implement LGMRES, which is a modification to GMRES that performs better on nearly-symmetric operators. This would probably be my job since I wrote the GMRES implementation, and I'd be happy to simply write it here instead.
Note that both LGMRES and bicgstab are accessible through SciPy.
> Functioning and apparently efficient JAX implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into JAX
Agreed, I would love to upstream all of these into JAX.