
Differentiable matrix-free linear algebra, optimization and equation solving

Open shoyer opened this issue 4 years ago • 38 comments

This is a meta-issue for keeping track of progress on implementing differentiable higher-order functions from SciPy in JAX, e.g.,

  • [ ] scipy.sparse.linalg.gmres and cg: matrix-free linear solves
  • [ ] scipy.sparse.linalg.eigs and eigsh: matrix-free eigenvalue problems
  • [ ] scipy.optimize.root: nonlinear equation solving
  • [ ] scipy.optimize.fixed_point: solving for fixed points
  • [ ] scipy.integrate.odeint: solving ordinary differential equations
  • [ ] scipy.optimize.minimize: nonlinear minimization

These higher-order functions are important for implementing sophisticated differentiable programs, both for scientific applications and for machine learning.

Implementations should leverage and build upon JAX's custom transformation capabilities. For example, scipy.optimize.root should leverage autodiff for calculating the Jacobians or Jacobian-vector products needed for Newton's method.
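As a concrete illustration, here is a minimal sketch (not tied to any particular JAX implementation) of a Newton step where autodiff supplies the Jacobian:

```python
import jax
import jax.numpy as jnp

def newton_step(f, x):
    # autodiff gives us the Jacobian of f at x for free
    J = jax.jacobian(f)(x)
    # standard Newton update: x <- x - J^{-1} f(x)
    return x - jnp.linalg.solve(J, f(x))

# e.g., finding a root of f(x) = x**3 - 1 starting from x = [2.0]
f = lambda x: x ** 3 - 1.0
x = jnp.array([2.0])
for _ in range(10):
    x = newton_step(f, x)
```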

In most cases, I think the right way to do this involves two separate steps, which could happen in parallel:

  1. Higher order primitives for defining automatic differentiation rules, but not specialized to any particular algorithm, e.g., lax.custom_linear_solve from https://github.com/google/jax/pull/1402.
  2. Implementations of particular algorithms for the forward problems, e.g., a conjugate gradient method for linear solves. These could either be implemented from scratch using JAX's functional control flow (e.g., while_loop) or could leverage existing external implementations on particular backends. Either way they will almost certainly need custom derivative rules, rather than differentiation through the forward algorithm. (A sketch combining both steps follows this list.)
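A hedged sketch of how the two pieces fit together, pairing the custom_linear_solve primitive from #1402 with the cg solver that #2566 later adds (this assumes a symmetric positive-definite A; it is an illustration, not the actual JAX wiring):

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.sparse.linalg import cg

def solve(matvec, b):
    # step 2: any forward algorithm could go here
    x, _ = cg(matvec, b)
    return x

def linear_solve(A, b):
    matvec = lambda x: A @ x
    # step 1: custom_linear_solve installs implicit-differentiation rules,
    # so gradients never unroll through the cg iterations themselves
    return lax.custom_linear_solve(matvec, b, solve, transpose_solve=solve,
                                   symmetric=True)
```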

There's lots of work to be done here, so please comment if you're interested in using or implementing any of these.

shoyer avatar Oct 20 '19 19:10 shoyer

I'll start things off.

I've written initial versions of higher order primitives defining gradients for non-linear root finding (see discussion in https://github.com/google/jax/issues/1448) and linear equation solving (https://github.com/google/jax/pull/1402).

I've also written experimental implementations in JAX for gmres and a Thick-Restart Lanczos method (for eigsh). These are not shared publicly yet, but I could drop them into jax.experimental. (EDIT: see https://github.com/google/jax/pull/3114 for thick-restart lanczos)

shoyer avatar Oct 20 '19 19:10 shoyer

@shoyer With this addition and, in general, implicit diff related features, are there any plans for a mechanism to extract intermediate/aux values when differentiating to allow us to log things like number of iterations for tangent solve, residual, etc.? (related: #1197 #844)

gehring avatar Oct 30 '19 17:10 gehring

Yes, I'm working on a general feature for referring to (and extracting, injecting, differentiating w.r.t.) intermediate/aux values; it's coming soon 🙂.

jekbradbury avatar Oct 30 '19 18:10 jekbradbury

I'm glad to hear @jekbradbury is thinking about this, because I have not! I agree it's important.

In this case, auxiliary outputs should not have derivatives defined. If we were to solve this entirely inside custom_linear_solve etc., then I would suggest simply adding an explicit has_aux argument that changes the function signature to return an extra output. But then it isn't obvious how we could pipe out auxiliary outputs from the forward or transpose passes.
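For reference, a minimal sketch of the has_aux convention as it already exists in jax.grad, which an explicit has_aux on custom_linear_solve would presumably mirror:

```python
import jax
import jax.numpy as jnp

def loss_with_aux(x):
    loss = jnp.sum(x ** 2)
    # auxiliary diagnostics that should not be differentiated
    return loss, {"niters": 3}

# with has_aux=True, jax.grad returns (gradient, aux) instead of just the gradient
grads, aux = jax.grad(loss_with_aux, has_aux=True)(jnp.ones(4))
```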

shoyer avatar Oct 30 '19 18:10 shoyer

@jekbradbury Awesome, looking forward to it!

(@ all JAX maintainers/contributors) I'm loving the style and direction of JAX, keep up the great work!

gehring avatar Oct 30 '19 18:10 gehring

#2566 adds cg, so at least that's a start.

shoyer avatar Apr 05 '20 00:04 shoyer

@shoyer, what's the status of gmres? I started its implementation today but then I realized you have something already.

romanodev avatar Jul 06 '20 04:07 romanodev

I have a very naive implementation of GMRES with preconditioning that you can find here: https://gist.github.com/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80

It needs more careful testing (and possibly improvements to the numerics) before merging into JAX.

shoyer avatar Jul 07 '20 21:07 shoyer

To be clear, I have no immediate plans to continue work on my gmres solver. If you want to take this on, that would be fantastic!

shoyer avatar Jul 07 '20 21:07 shoyer

@shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.

romanodev avatar Jul 11 '20 14:07 romanodev

@romanodev this is really awesome work.

mattjj avatar Jul 11 '20 20:07 mattjj

@shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.

This other PR has a vectorized version of Gram Schmidt, which I think could replace _inner in my implementation of GMRES above: https://github.com/google/jax/pull/3114 (it is basically the same algorithm)
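For readers following along, a minimal sketch of what a vectorized (classical) Gram-Schmidt step looks like, replacing a Python loop over basis vectors with matrix products (names here are illustrative, not those used in #3114):

```python
import jax.numpy as jnp

def gram_schmidt_step(Q, q):
    # Q: (n, k) matrix whose columns form an orthonormal basis; q: (n,) new vector
    h = Q.conj().T @ q  # all projection coefficients in one matvec
    q = q - Q @ h       # subtract every projection in one matvec
    return q, h
```

Note that classical Gram-Schmidt is less numerically stable than the modified variant, so in practice one often runs a second orthogonalization pass.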

shoyer avatar Jul 11 '20 21:07 shoyer

OK, here's a version with vectorized Gram Schmidt, which is perhaps 10-20% faster on CPU and ~10x faster on GPUs: https://gist.github.com/shoyer/dc33a5850337b6a87d48ed97b4727d29

The GPU is still spending most of its time waiting when solving for a 100-dimensional vector, but for large enough solves the performance should be reasonable.

I suspect the main remaining improvement (which we should probably have before merging) would be adding some form of early termination based on residual values.

shoyer avatar Jul 12 '20 02:07 shoyer

@shoyer, great. Any reason for not using experimental.loops?

Just for the sake of prototyping, I rewrote your version using loops (just personal taste, for faster iteration). Using jnp.linalg.lstsq, I plotted the residual at each iteration. Within loops we can easily handle early termination, but I am not sure how to do it with lax.scan.

In any case, the preconditioning does not seem to work yet. I considered a simple case where we take the inverse of the diagonal of A, and it doesn't match scipy. Although this choice is not justified since A does not have a dominant diagonal, it should at least serve for testing.
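For clarity, the preconditioner described here is the standard Jacobi (diagonal) one; a sketch might look like:

```python
import jax.numpy as jnp

def jacobi_preconditioner(A):
    # approximate A by its diagonal, so applying M^-1 is elementwise division
    d = jnp.diagonal(A)
    return lambda x: x / d
```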

Here is the gist:

https://gist.github.com/romanodev/e3f6bd23c499cd8a5f26b26c140abcac

romanodev avatar Jul 12 '20 19:07 romanodev

Any reason for not using experimental.loops?

It's just a matter of taste. Personally I don't find it much clearer than using functions if there's only one level of loops.

It's also a little easier to avoid inadvertently using Python control flow instead of the control flow functions, e.g., in your example there's one place where you should be using s.range() instead of range() (at least in the "real" version; maybe you avoided that intentionally for printing?).

For printing intermediate outputs from inside compiled code like "for" loops, take a look at jax.experimental.host_callback: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html
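A short sketch of that approach using id_print from the module linked above (the `what` label and function name here are illustrative):

```python
from jax.experimental import host_callback as hcb

def log_residual(residual):
    # prints the value on the host each time it is computed,
    # and returns it unchanged so it can be used in the loop
    return hcb.id_print(residual, what="gmres residual")
```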

For early termination logic in general, take a look at the source code for cg. The loop needs to be written in terms of while_loop instead of scan: https://github.com/google/jax/blob/0d81e988d8884a95ced928b026fec003bb34af57/jax/scipy/sparse/linalg.py#L54
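A hedged sketch of that pattern, following the cg source linked above (`step` is a stand-in for one iteration of whatever solver you are writing):

```python
import jax.numpy as jnp
from jax import lax

def solve_with_early_exit(step, x0, r0, tol=1e-5, maxiter=100):
    def cond_fun(state):
        _, r, k = state
        # keep iterating while the residual is large and budget remains
        return (jnp.linalg.norm(r) > tol) & (k < maxiter)

    def body_fun(state):
        x, r, k = state
        x, r = step(x, r)  # one iteration of the underlying solver
        return x, r, k + 1

    x, _, _ = lax.while_loop(cond_fun, body_fun, (x0, r0, 0))
    return x
```

(Note that while_loop is not reverse-mode differentiable, which is another reason these solvers need the custom derivative rules discussed above.)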

shoyer avatar Jul 13 '20 06:07 shoyer

Hi, in parallel with this, I'd like to add an implementation for bicgstab (which I believe is also matrix free)! I've started working on something similar in cupy (https://github.com/cupy/cupy/pull/3569) so I thought I might as well add it to jax.

sunilkpai avatar Jul 17 '20 02:07 sunilkpai

Sure, we would love to see an implementation of bicgstab. The implementation should be relatively straightforward (easier than gmres). Please copy the style of the existing cg implementation, including tests.

shoyer avatar Jul 17 '20 02:07 shoyer

Excellent! If GMRES works out too, then we can even do a more general form of BiCGSTAB, which may be useful for EM problems (as suggested by meep). Not sure if this is implemented in scipy, so unsure where something like that would go?

sunilkpai avatar Jul 17 '20 03:07 sunilkpai

Meanwhile, as a reference, this is what I have so far on GMRES (using while_loop, still needs to be tested carefully)

https://gist.github.com/romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f

romanodev avatar Jul 17 '20 04:07 romanodev

With regards to the "more general form of BiCGSTAB", I'm sure this could be useful but my preference would be to stick to established algorithms in JAX. Inclusion in SciPy is one good indication of that. I would suggest releasing more experimental algorithms in separate packages.

shoyer avatar Jul 17 '20 05:07 shoyer

Hi @shoyer, I have a general question as I prepare this PR. What is the motivation of defining pytree functions like _add and _sub in the _cg_solve method? Is this for jit compilation? It appears I will need to add a more general _vdot_tree function for bicgstab to work for more general matrices, which is why I'm asking!

sunilkpai avatar Jul 18 '20 17:07 sunilkpai

Hi @shoyer, I have a general question as I prepare this PR. What is the motivation of defining pytree functions like _add and _sub in the _cg_solve method? Is this for jit compilation? It appears I will need to add a more general _vdot_tree function for bicgstab to work for more general matrices, which is why I'm asking!

The reason for doing this is that it lets us solve linear equations on arbitrarily shaped arrays in arbitrary pytree structures. In some cases, this is a much more natural way to implement linear operators, e.g., you might use a handful of 3D arrays to represent a solution to a PDE on a grid.

By supporting pytrees, we don't need to copy these arrays into a single flattened vector. Flattening can actually add significant overhead due to extra copying, e.g., it's 60-90% slower to solve a Poisson equation in 2D using CG with a flattened array than with a 2D array: https://gist.github.com/shoyer/6826d02949e4d2ce82122a8bd5c62cf7
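For instance, a sketch of the kind of _vdot_tree helper mentioned above, built only on tree_leaves (illustrative, not the exact helper in cg):

```python
import jax.numpy as jnp
from jax import tree_util

def _vdot_tree(x, y):
    # dot product over matching leaves, so inner products (and norms,
    # via _vdot_tree(x, x)) work on arbitrary pytrees without flattening
    return sum(jnp.vdot(a, b)
               for a, b in zip(tree_util.tree_leaves(x),
                               tree_util.tree_leaves(y)))
```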

That said, these algorithms are much easier to implement/verify on single vectors, and in the long term I'd like to solve vectorization with a tree_vectorize code transformation instead -- see https://github.com/google/jax/pull/3263. This will hopefully be merged into JAX within the next month or so.

So if you prefer, I would also be OK with implementing these algorithms on single vectors and adding explicit flattening/unflattening to handle pytrees. You can find an example of doing this sort of thing inside _odeint_wrapper -- you could basically use the exact same thing for something like cg by omitting the jax.vmap at the end: https://github.com/google/jax/blob/fa2a0275c83c03cfad9d36f5da06b0bf47eedfb9/jax/experimental/ode.py#L210-L214
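A hedged sketch of that flatten/unflatten approach, mirroring the ravel_pytree pattern in _odeint_wrapper (the wrapper name is illustrative):

```python
from jax.flatten_util import ravel_pytree

def flat_wrapper(solver, matvec, b):
    b_flat, unravel = ravel_pytree(b)

    def matvec_flat(x_flat):
        y = matvec(unravel(x_flat))  # apply the operator on the pytree
        y_flat, _ = ravel_pytree(y)
        return y_flat

    x_flat = solver(matvec_flat, b_flat)  # the solver only ever sees 1D arrays
    return unravel(x_flat)
```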

shoyer avatar Jul 18 '20 18:07 shoyer

@shoyer I added a PR for bicgstab but it's still WIP; I was just looking for some comments before I finalize the implementation and tests, but I think it's close! I think changing to the flattened version would be easy to extend to both cg and bicgstab, so we could handle that in a separate PR?

sunilkpai avatar Jul 19 '20 21:07 sunilkpai

I would just like to add another comment on @shoyer's point about supporting operators on pytree arrays. I have a use case where I need to compute the vector norm / dot product for arrays structured as pytrees, so it would be very useful to support these operations.

ethanluoyc avatar Jul 20 '20 15:07 ethanluoyc

I think changing to the flattened version would be easy to extend to both cg and bicgstab, so we could handle that in a separate PR?

To be clear, this was a suggestion for making it easier to write new solvers. Actually, it would even be fine not to support pytrees at all on new solvers.

cg already supports pytrees without flattening, and we should keep this functionality!

shoyer avatar Jul 21 '20 03:07 shoyer

Hi @shoyer, re: #3796 I think we should think about either building a more robust testing pipeline for all matrix-free methods or just copying scipy's tests and using that as our standard. What do you think are the appropriate next steps?

sunilkpai avatar Jul 29 '20 18:07 sunilkpai

Hi @shoyer, re: #3796 I think we should think about either building a more robust testing pipeline for all matrix-free methods or just copying scipy's tests and using that as our standard. What do you think are the appropriate next steps?

We should absolutely feel free to copy SciPy's tests. https://github.com/google/jax/pull/3101 (which I will be merging shortly) has a good example of how to do this.

shoyer avatar Jul 29 '20 18:07 shoyer

Functioning and apparently efficient JAX implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into JAX. I've started a conversation in https://github.com/google/TensorNetwork/issues/785 with the other TensorNetwork devs on the matter.

We are also planning to implement LGMRES, which is a modification to GMRES that performs better on nearly-symmetric operators. This would probably be my job since I wrote the GMRES implementation, and I'd be happy to simply write it here instead.

alewis avatar Aug 20 '20 16:08 alewis

Note that both LGMRES and bicgstab are accessible through SciPy.

alewis avatar Aug 20 '20 16:08 alewis

Functioning and apparently efficient JAX implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into JAX

Agreed, I would love to upstream all of these into JAX.

shoyer avatar Aug 20 '20 17:08 shoyer