
sparse eigenvalue solvers

Open mganahl opened this issue 5 years ago • 47 comments

Hi!

Our team is currently developing sparse symmetric and non-symmetric eigenvalue solvers in JAX (implicitly restarted Arnoldi and Lanczos). I've heard there are efforts to support this natively in JAX, and I wanted to ask what the status is. It would be great if JAX provided these.

mganahl avatar May 15 '20 22:05 mganahl

By "solvers" are you referring specifically to eigenvalue solvers?

Yes, I think we would be interested in adding some of these to JAX, but they would need to be reasonably robust and scalable, ideally comparable to the SciPy algorithms. The "basic" Lanczos method would probably not qualify, but implicitly restarted Arnoldi (the algorithm SciPy uses via ARPACK) probably would.
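For concreteness, the "basic" method in question can be sketched in a few lines: a minimal, unrestarted Lanczos with no reorthogonalization, which is precisely why it isn't robust enough on its own.

```python
import jax.numpy as jnp

# Minimal unrestarted Lanczos sketch for a symmetric operator given as a
# matvec. It builds the tridiagonal matrix T whose eigenvalues (Ritz values)
# approximate the extremal eigenvalues of A. No reorthogonalization, no
# restarts; in floating point the basis loses orthogonality as m grows.
def lanczos(matvec, v0, m):
    v = v0 / jnp.linalg.norm(v0)
    v_prev = jnp.zeros_like(v)
    beta_prev = 0.0
    alphas, betas = [], []
    for _ in range(m):
        w = matvec(v) - beta_prev * v_prev
        alpha = v @ w
        w = w - alpha * v  # three-term recurrence: orthogonalize against v only
        beta = jnp.linalg.norm(w)
        alphas.append(alpha)
        betas.append(beta)
        v_prev, v, beta_prev = v, w / beta, beta
    return (jnp.diag(jnp.array(alphas))
            + jnp.diag(jnp.array(betas[:-1]), 1)
            + jnp.diag(jnp.array(betas[:-1]), -1))

# Ritz values from a 9-step run on a 10x10 diagonal test operator.
T = lanczos(lambda x: jnp.arange(1.0, 11.0) * x, jnp.ones(10), 9)
```

This is a sketch, not a proposed implementation; a production version would need reorthogonalization and restarts as discussed above.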

It would also be nice to include auto-diff rules for these operations, but that could come later.

shoyer avatar May 15 '20 22:05 shoyer

I played around with writing a thick-restart Lanczos solver. As a point of reference, you can find it in https://github.com/google/jax/pull/3114 but I'm not sure we actually want to merge it.

shoyer avatar May 15 '20 23:05 shoyer

We have an implementation at https://github.com/google/TensorNetwork/blob/d4ec0a381dbf1a7d453d97525ccc857fc416b575/tensornetwork/backends/jax/jax_backend.py#L230

Not sure if it would be useful to have in JAX proper.

chaserileyroberts avatar May 16 '20 00:05 chaserileyroberts

This version does not currently support restarts and is in general geared towards tensor network applications, but it could serve as a starting point for an implicitly restarted Lanczos.

mganahl avatar May 16 '20 00:05 mganahl

I had a look at #3114, and this would indeed be interesting for us. The three problem classes that pop up most frequently in our applications are sparse symmetric eigenproblems (mostly SA or LA eigenvalues, usually only one to a few), sparse non-symmetric eigenproblems (LR eigenvalues), and linear system solvers (symmetric and non-symmetric, i.e. GMRES or LGMRES), all of which we are currently working on. The requirements in terms of robustness and accuracy are often less rigorous for tensor networks than for other applications, hence our implementations are more mundane. But if the JAX team were interested in supporting these as well, it could make sense to join efforts.

mganahl avatar May 16 '20 01:05 mganahl

GMRES and L-GMRES for linear solves would also be quite welcome in JAX. These would fit in quite naturally alongside our new CG solver.
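For reference, the CG solver in question is matrix-free: the operator is passed as a function, which is the natural interface for the Krylov solvers discussed in this thread. A small sketch on a 1-D Laplacian (SPD) operator:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def matvec(x):
    # Tridiagonal SPD operator 2*x[i] - x[i-1] - x[i+1], Dirichlet boundaries.
    return 2.0 * x - jnp.pad(x[1:], (0, 1)) - jnp.pad(x[:-1], (1, 0))

b = jnp.ones(10)
# The operator is never materialized as a matrix; cg only calls matvec.
x, _ = cg(matvec, b, tol=1e-8)
```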

shoyer avatar May 16 '20 01:05 shoyer

It would also be nice to include auto-diff rules for these operations, but that could come later.

Just a note on this: here I have defined an autograd primitive wrapping scipy.sparse.linalg.eigsh (applied to regular NumPy matrices, not sparse ones). The vjp is almost identical to the one for numpy.linalg.eigh, but the summation in the backprop of the eigenvector gradient is restricted to only the numeig computed eigenvectors. This makes the eigenvector gradient approximate (it becomes exact in the limit of numeig approaching the linear size of the matrix). The gradient w.r.t. the eigenvalues, however, is exact. So you'll have to decide whether to only support auto-diff w.r.t. the eigenvalues (I assume you don't want approximate results...)
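The restricted-sum vjp described above translates directly to jax.custom_vjp. Here is a hedged sketch, using a dense jnp.linalg.eigh truncated to numeig pairs as a stand-in for eigsh (wrapping SciPy itself would need a host callback); `eigh_partial` is a hypothetical name, and degeneracies are ignored.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Sketch of the approximate vjp: sum only over the `numeig` computed pairs.
# Assumes a symmetric `a` with non-degenerate computed eigenvalues.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def eigh_partial(a, numeig):
    w, v = jnp.linalg.eigh(a)           # eigenvalues in ascending order
    return w[-numeig:], v[:, -numeig:]  # keep the top `numeig` pairs

def _fwd(a, numeig):
    out = eigh_partial(a, numeig)
    return out, out

def _bwd(numeig, res, g):
    w, v = res
    gw, gv = g
    # Exact eigenvalue term: dA = V diag(gw) V^T.
    da = (v * gw) @ v.T
    # Approximate eigenvector term, restricted to the computed pairs:
    # F_ij = 1 / (w_j - w_i) off the diagonal, as in the full eigh vjp.
    dw = w[None, :] - w[:, None]
    f = jnp.where(jnp.eye(numeig, dtype=bool), 0.0,
                  1.0 / jnp.where(dw == 0.0, 1.0, dw))
    da = da + v @ (f * (v.T @ gv)) @ v.T
    return (0.5 * (da + da.T),)  # symmetrize for a symmetric input

eigh_partial.defvjp(_fwd, _bwd)
```

Since the eigenvalue term involves only the computed pairs, gradients of the eigenvalues come out exact, matching the observation above; only the eigenvector term is truncated.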

momchilmm avatar May 28 '20 00:05 momchilmm

Wowww cool, that is really helpful @momchilmm !

mattjj avatar May 28 '20 02:05 mattjj

Indeed, very cool to see! I also made a note with a few references about methods for calculating eigenvector derivatives from a partial decomposition over in https://github.com/google/jax/pull/3114.

shoyer avatar May 28 '20 02:05 shoyer

@shoyer oh this is very interesting! I might consider adding something like this in my autograd implementation because storing all (or many of) the eigenvectors to get the exact gradient is sometimes a significant memory overhead.

momchilmm avatar May 28 '20 02:05 momchilmm

@shoyer so I wrote a method to compute the eigenvector derivatives from a partial decomposition, along the lines of the works you had pointed out, and it works! Here's the vjp, and a test. The works mentioned in #3114 are overly complex because they consider degenerate eigenvalues, so I ended up following Steven Johnson's notes here, which I extended to the case of a Hermitian matrix with a small modification.

So yeah, this seems to work for non-degenerate eigenvalues. I do think (similarly to you, I believe) that the gradient is not a well-defined quantity in the case of degenerate eigenvalues, specifically when there is more than one input parameter. Basically, for a single parameter, you can define the derivative of each eigenvalue by choosing eigenvectors in your subspace that are also eigenvectors of the system with the corresponding small perturbation. However, this choice of eigenvectors will be different for every input parameter (each corresponding to a different matrix perturbation), and so a full gradient cannot be defined.

momchilmm avatar May 28 '20 23:05 momchilmm

~~@momchilmm I've been playing around with your vjp, and I occasionally get wildly inaccurate gradients. I'm guessing this is related to the cg-solve operating on a necessarily non-positive-definite matrix. Have you encountered this? Any thoughts?~~

Update: this was due to an error in my eigenvector calculation (signs were inconsistent across evaluations in the forward-difference calculation I was using to check the gradient implementation).

jackd avatar Feb 02 '21 06:02 jackd

@jackd I think you are right! I have not encountered it myself, since it seems to just work in some cases, but I can see how it might be a problem in others. In the Steven Johnson notes there's a footnote:

> Since P commutes with A − α, we can solve for λ₀ easily by an iterative method such as conjugate gradient: if we start with an initial guess orthogonal to x, all subsequent iterates will also be orthogonal to x and will thus converge to λ₀ (except for roundoff, which can be corrected by multiplying the final result by P).

However, I don't think that would guarantee convergence, unless I'm missing something? I'm also not sure what the best method for this matrix is; we only know that it's Hermitian. But you could try just changing the solver to bicg or gmres and see if your troubles go away. If it becomes stable, I encourage you to submit a PR!
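The footnote's projection idea can be sketched with JAX's CG on the deflated system (`solve_deflated` is a hypothetical helper name; this assumes a real symmetric `a`, a normalized eigenvector `x` with eigenvalue `lam`, and no degeneracy):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_deflated(a, lam, x, b):
    # P = I - x x^T projects onto the complement of the known eigenvector.
    project = lambda v: v - x * (x @ v)
    # On that complement, P (A - lam I) P is positive definite when lam is
    # the smallest eigenvalue, so CG applies; the default x0 = 0 starts
    # (and by the commuting argument stays) orthogonal to x.
    matvec = lambda v: project((a - lam * jnp.eye(a.shape[0])) @ project(v))
    y, _ = cg(matvec, project(b), tol=1e-8)
    return project(y)  # apply P once more to correct for roundoff
```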

momchilmm avatar Feb 02 '21 17:02 momchilmm

@momchilmm glad I'm not going crazy :). I'm not having trouble with the smallest eigenvalue (and if I were, I suspect a small diagonal shift might be enough to resolve it), but I am finding my tests often break as I increase the number of eigenvectors solved for.

I'm looking at MINRES-QLP as a substitute, as it explicitly caters for singular symmetric matrices with preconditioning; alas, there's no JAX implementation of that, so it's not just a one-line change :S.

jackd avatar Feb 02 '21 21:02 jackd

Yeah, there doesn't seem to be anything in scipy.sparse.linalg that's expected to always work. I think the footnote I quoted above means that in most cases you don't have to worry about the matrix being singular: if you start with an initial guess that's orthogonal to the eigenvector (which spans the kernel), you'll always stay outside the kernel. So the fact that the matrix is not positive-definite is the issue. The reason you don't have trouble with the smallest eigenvalue is that in the iteration the vector is restricted to the subspace orthogonal to the corresponding eigenvector, and all the eigenvalues acting on that subspace are larger than 0 (they are shifted by the smallest eigenvalue).

MINRES-QLP seems to be the only thing I can find for Hermitian indefinite matrices too. There's a freely available Python implementation if you want to try your luck...

momchilmm avatar Feb 02 '21 21:02 momchilmm

Sigh, false alarm. It turns out my errors were due to my LOBPCG implementation returning eigenvectors with different signs in the forward-difference gradient check. I'll keep this in mind if I get any more NaN errors / gradient mismatches in the future, but I have no evidence to suggest this is a problem at this stage. Thanks anyway :)

jackd avatar Feb 02 '21 22:02 jackd

Ah, yeah. If your objective function depends on the sign (or, more generally, the complex phase) of the eigenvectors, and you're not setting it deterministically in your solver, then you're in trouble. That's one of the main reasons this issue exists.

momchilmm avatar Feb 02 '21 22:02 momchilmm

If anyone's looking for a LOBPCG implementation, I've hacked up a basic implementation here. The sparse implementation works without jitting, but I need to work out how to resolve this issue before the sparse jitted version will work.

There's also some vjps here based on @momchilmm's work above. Needs some more documentation, but hopefully the tests show integration with lobpcg.

jackd avatar Feb 03 '21 07:02 jackd

> If anyone's looking for a LOBPCG implementation I've hacked up a basic implementation here. The sparse implementation works without jitting - but I need to work out how to resolve this issue before the sparse jitted version will work.
>
> There's also some vjps here based on @momchilmm's work above. Needs some more documentation, but hopefully the tests show integration with lobpcg.

See also https://github.com/eserie/jju

lobpcg avatar Mar 28 '22 16:03 lobpcg

Hey @jackd , are you planning on putting together a PR to land a version of lobpcg into jax core?

vlad17 avatar Apr 01 '22 21:04 vlad17

@vlad17 no plans right now. I'm willing to bet it wouldn't be a high priority for the jax core team even if it fit the package (from memory, mine is based on a PyTorch implementation rather than the SciPy one). Anyone else should feel free to use it however they want (e.g. for their own PR), though I haven't looked at it in a while, so @lobpcg's more recent adaptation is probably in a better state.

jackd avatar Apr 01 '22 23:04 jackd

@lobpcg I've been toying with a simplified but fairly stable version of LOBPCG (as in, it seems to perform better than SciPy's) which solves only the standard eigenvalue problem. I think JAX users would benefit from a PR that puts this into an experimental jax directory. Would you be able to help review it for correctness?

vlad17 avatar Apr 12 '22 18:04 vlad17

@mganahl jax.experimental.sparse.linalg.lobpcg_standard is in master. Right now it's top-k only with no jvp, so it's pretty barebones, but feel free to give it a spin and let me know what you think.
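For anyone wanting to try it, a minimal usage sketch: the solver takes the operator, an (n, k) initial guess block, and returns approximate eigenvalues, eigenvectors, and the iteration count. The diagonal test matrix here is purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

n, k = 20, 2
a = jnp.diag(jnp.arange(1.0, n + 1.0))                 # symmetric test operator
x0 = jax.random.normal(jax.random.PRNGKey(0), (n, k))  # initial search block
theta, u, iters = lobpcg_standard(a, x0)               # top-k eigenpairs of a
```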

vlad17 avatar Jul 06 '22 03:07 vlad17

Thanks a ton for your work on this. jax.experimental.sparse has been really helpful for my projects, and I appreciate that it's actively developed. Something that would be nice is support for a preconditioner argument and bottom-k eigenvalues. I saw that there are already some comments in master about it. I have been doing this with some success on big Laplacians and am happy to submit a PR, but I'm unsure whether there are design considerations I'm missing.

On something else: I know preconditioning has been discussed in the main matrix-free thread (#1531), but I also want to express my interest in something like an XLA-accelerated multigrid preconditioner.

choltz95 avatar Jan 02 '23 22:01 choltz95

Bottom-k and a preconditioner are reasonable next steps; I could add them (or review a PR) if there are enough use cases to make them high priority.

@rmlarsen mentioned some interest in bottom-k, though not sure if he knows about possible users for it.

vlad17 avatar Jan 03 '23 01:01 vlad17

@rmlarsen @choltz95 Getting the bottom-k eigenvalues of a matrix A is trivial with the existing code: just run it on the negated matrix -A and flip the signs of the resulting eigenvalues, entirely at the user level.

@vlad17 Or modify the code to select the smallest eigenvalues instead of the largest in the Rayleigh-Ritz (RR) step; cf. the option "largest" in https://github.com/scipy/scipy/blob/main/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py
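The user-level negation trick described above can be sketched as a small wrapper (`bottom_k` is a hypothetical helper name): run the top-k solver on -A, then flip the signs of the returned eigenvalues.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

def bottom_k(a, k, key):
    # Top-k eigenpairs of -A are the bottom-k eigenpairs of A, with the
    # eigenvalues negated; the eigenvectors carry over unchanged.
    x0 = jax.random.normal(key, (a.shape[0], k))
    theta, u, _ = lobpcg_standard(-a, x0)
    return -theta, u
```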

lobpcg avatar Jan 03 '23 19:01 lobpcg

When I said "smallest" I meant smallest in absolute value. Of course flipping the sign of A is trivial. Computing the smallest eigenvalues directly is extremely useful when dealing with ill-conditioned matrices in a variety of applications, and when the matrix is indeed Hermitian, doing so instead of "naively" using an SVD solver can be much faster or more accurate.

rmlarsen avatar Jan 03 '23 20:01 rmlarsen

How hard would it be to add a shift-invert mode to the current lobpcg solver?

jakevdp avatar Jan 03 '23 20:01 jakevdp

The way I read the paper, the beautiful thing about LOBPCG is its ability to solve for 1/λ by a change of variables in the generalized eigenvalue problem. It may not be super efficient, but it's easy to use and presumably much more efficient than a naive shift-invert implementation(?)

rmlarsen avatar Jan 03 '23 20:01 rmlarsen

For my specific use case with Laplacians, there isn't much ambiguity about the eigenvalues (large, small, etc. are all nonnegative), so flipping the order of the RR step as previously mentioned works great. At least for me, the important thing is preconditioning (which is already there, just commented out): it can be valuable for large ill-conditioned problems.

choltz95 avatar Jan 03 '23 21:01 choltz95