Preconditioner for transposed Krylov methods
I'm using `lx.GMRES` with a preconditioner and trying to understand why the forward solve converges fine but the gradient doesn't. Looking through the source, it seems like `solver.transpose` throws away the preconditioner that's passed in `options`:
https://github.com/patrick-kidger/lineax/blob/66b7b5327a44e4b944a8ce9242773150e8a8d811/lineax/_solver/gmres.py#L416-L420
So the tangent solve here: https://github.com/patrick-kidger/lineax/blob/66b7b5327a44e4b944a8ce9242773150e8a8d811/lineax/_solve.py#L285-L294
is done without preconditioning, which seems wrong. Am I misreading something?
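To illustrate what I'd expect instead (a sketch with a made-up toy system, not lineax code): if `M ≈ A^{-1}` preconditions the forward solve, then `M^T ≈ A^{-T}` is the natural preconditioner for the transpose solve, and one can pass it explicitly as a workaround:

```python
import numpy as np
import jax.numpy as jnp
import lineax as lx

rng = np.random.default_rng(0)
# Small diagonally dominant test system, made up purely for illustration.
A = lx.MatrixLinearOperator(jnp.array(rng.random((5, 5)) + 5 * np.eye(5)))
b = jnp.array(rng.random(5))

# Transposed preconditioner: M^T = (A^{-1})^T approximates A^{-T}.
# (The positive_semidefinite_tag is the same #139 workaround as below.)
M_T = lx.MatrixLinearOperator(
    jnp.linalg.inv(A.matrix).T,
    tags=lx.positive_semidefinite_tag,
)

# Solve A^T x = b with the transposed preconditioner passed explicitly,
# rather than relying on solver.transpose to carry it through.
x = lx.linear_solve(
    A.transpose(),
    b,
    solver=lx.GMRES(rtol=1e-8, atol=1e-8),
    options={"preconditioner": M_T},
).value
```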
It seems like there may also be something funky going on with the stopping criteria on forward solves:
```python
import numpy as np
import jax.numpy as jnp
import lineax as lx

# random nonsymmetric, poorly conditioned matrix
A = np.random.random((10, 10))
A += np.diag(np.arange(10) ** 6)
b = np.random.random(A.shape[0])
A = lx.MatrixLinearOperator(jnp.array(A))
M = lx.MatrixLinearOperator(
    jnp.linalg.inv(A.matrix),  # exact inverse, should only take 1 iteration
    tags=lx.positive_semidefinite_tag,  # needed because of #139
)
x = lx.linear_solve(
    A, b,
    solver=lx.GMRES(1e-8, 1e-8, max_steps=1, restart=1),
    options={"preconditioner": M},
    throw=True,
).value
```
This should converge in a single iteration, but it doesn't. If I set throw=False and compute the actual residual, it's still quite large (>1). Increasing max_steps to 2 brings the residual down to machine precision (so maybe just an off-by-one error in the iteration counting?), but the solve is still reported as failed if throw=True.
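For reference, here's the residual check I mean, continuing from the snippet above (same A, b, M):

```python
# throw=False lets the result be inspected even when lineax reports failure.
out = lx.linear_solve(
    A, b,
    solver=lx.GMRES(1e-8, 1e-8, max_steps=2, restart=1),
    options={"preconditioner": M},
    throw=False,
)
print(np.linalg.norm(A.mv(out.value) - b))  # ~machine precision at max_steps=2
```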
Hmm, throwing away the preconditioner is probably a mistake. Happy to take a PR on that! (And ideally a test 😄)
As for your second comment, I'm away from my laptop right now, but I agree this seems undesirable. I think GMRES has been one of our less-loved children, so it's possible that we've missed something in its implementation.
Sure, I'll try to put together a PR for the preconditioner issue this week, and also try to fix some of the issues from #139.
For the termination criteria, here's a more complete example:
```python
import numpy as np
import jax.numpy as jnp
import lineax as lx

# random nonsymmetric, poorly conditioned matrix
A = np.random.random((10, 10))
A += np.diag(np.arange(10) ** 6)
b = np.random.random(A.shape[0])
A = lx.MatrixLinearOperator(jnp.array(A))
M = lx.MatrixLinearOperator(
    jnp.linalg.inv(A.matrix),  # exact inverse, should only take 1 iteration
    tags=lx.positive_semidefinite_tag,  # needed because of #139
)

print("======GMRES======")
for i in range(1, 7):
    print("max_steps=", i)
    out = lx.linear_solve(
        A, b,
        solver=lx.GMRES(1e-4, 1e-4, max_steps=i, restart=1),
        options={"preconditioner": M},
        throw=False,
    )
    x = out.value
    print(f"norm(Ax-b)={np.linalg.norm(A.mv(x) - b):.3e}")
    print("result=", lx.RESULTS[out.result])

print("======BiCGStab======")
for i in range(1, 7):
    print("max_steps=", i)
    out = lx.linear_solve(
        A, b,
        solver=lx.BiCGStab(1e-4, 1e-4, max_steps=i),
        options={"preconditioner": M},
        throw=False,
    )
    x = out.value
    print(f"norm(Ax-b)={np.linalg.norm(A.mv(x) - b):.3e}")
    print("result=", lx.RESULTS[out.result])
```
Output:

```
======GMRES======
max_steps= 1
norm(Ax-b)=1.870e+00
result= The maximum number of solver steps was reached. Try increasing `max_steps`.
max_steps= 2
norm(Ax-b)=1.087e-15
result= The maximum number of solver steps was reached. Try increasing `max_steps`.
max_steps= 3
norm(Ax-b)=3.925e-17
result= The maximum number of solver steps was reached. Try increasing `max_steps`.
max_steps= 4
norm(Ax-b)=3.925e-17
result=
max_steps= 5
norm(Ax-b)=3.925e-17
result=
max_steps= 6
norm(Ax-b)=3.925e-17
result=
======BiCGStab======
max_steps= 1
norm(Ax-b)=6.206e-17
result= The maximum number of solver steps was reached. Try increasing `max_steps`.
max_steps= 2
norm(Ax-b)=6.206e-17
result= The maximum number of solver steps was reached. Try increasing `max_steps`.
max_steps= 3
norm(Ax-b)=6.206e-17
result=
max_steps= 4
norm(Ax-b)=6.206e-17
result=
max_steps= 5
norm(Ax-b)=6.206e-17
result=
max_steps= 6
norm(Ax-b)=6.206e-17
result=
```
So GMRES takes 1 additional iteration, which I think is because the first iteration of GMRES is actually a no-op; I'm not sure why this is. If we want to keep it that way, I think we should at least modify the counting so that max_steps is the actual number of restarts.
However, even accounting for that, it still seems like it takes more steps than necessary. I think another possible issue is the Cauchy-like termination criterion: https://github.com/patrick-kidger/lineax/blob/66b7b5327a44e4b944a8ce9242773150e8a8d811/lineax/_solver/gmres.py#L132-L143
If we converge in 1 step, then diff will be much larger than y_scale even though r << b_scale, which will force it to take another step. Termination on diff << y_scale is not something I've seen in the literature or in other libraries. It's probably not too big an issue for CG/BiCGStab, where an additional iteration is only 1-2 matvecs, but for GMRES an additional iteration can be tens to hundreds of matvecs, so the extra cost is not negligible.
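For comparison, the purely residual-based test used in most references and libraries (e.g. SciPy's gmres) looks roughly like this (a sketch; the function name is mine):

```python
import jax.numpy as jnp

def residual_converged(r, b, rtol, atol):
    # Standard residual test: stop once ||r|| <= atol + rtol * ||b||,
    # regardless of how much the iterate moved in the last cycle.
    return jnp.linalg.norm(r) <= atol + rtol * jnp.linalg.norm(b)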
And on top of all that, I think there's also an issue here: https://github.com/patrick-kidger/lineax/blob/66b7b5327a44e4b944a8ce9242773150e8a8d811/lineax/_solver/gmres.py#L224-L227
where, if the solve reaches the tolerance on the final step, it returns `max_steps_reached` instead of success.
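In other words, I'd expect the ordering to be something like this (a hypothetical sketch, not lineax's actual code):

```python
def classify(converged: bool, step: int, max_steps: int) -> str:
    # Hypothetical: convergence takes priority over the step budget, so
    # hitting the tolerance on the final step still counts as success.
    if converged:
        return "successful"
    if step >= max_steps:
        return "max_steps_reached"
    return "in_progress"
```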
Not sure what the best approach is for these issues, as fixing some of them may be breaking changes.
> So GMRES takes 1 additional iteration, which I think is because the first iteration of GMRES is actually a no-op; I'm not sure why this is.
I didn't recall either, so I messaged @packquickly, who originally implemented this:

> IIRC it's a way of handling the dummy initialization of r0. We don't want to introduce an extra matvec to compute the value of r0 earlier in the code, to save compile time, so we set a dummy value of 0. However, each pass of _gmres_compute does compute the actual residual r, so in the first pass we don't attempt to actually perform the GMRES algorithm on y and instead just use the loop to compute r0.
I imagine we could adjust things here to either do that in a different way or adjust the counter by one.
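For example, one possible restructuring (a rough sketch under assumptions about the state layout, not the actual _gmres_compute internals; it trades one extra traced matvec for an honest step count):

```python
def init_state(operator, preconditioner, vector, y0):
    # Compute the true initial residual eagerly instead of seeding the
    # loop with a dummy zero, so the first loop pass can run a real
    # GMRES cycle. Costs one extra matvec at trace time.
    r0 = vector - operator.mv(y0)
    z0 = preconditioner.mv(r0)  # preconditioned residual
    return y0, r0, z0
```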
> The termination on diff << y_scale is not something I've seen in the literature or other libraries.
Right, this is us trying to standardize our termination criteria across ODEs/root-finding/linear solves/etc. The literature as a whole is pretty wildly inconsistent, with each author just kind of making up their own choice.
That said, we shouldn't let purity stand in the way of pragmatism. I'd be happy to hear a suggestion for how we might tweak this.
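For instance, perhaps something like accepting either the residual test or the existing Cauchy-style test, so that a solve which nails the residual in one Krylov cycle terminates immediately (just a sketch, names are made up, not a commitment):

```python
import jax.numpy as jnp

def terminate(r, b, diff, y, rtol, atol):
    # Either test suffices: the residual check catches one-cycle
    # convergence, while the diff/y check is kept as an alternative.
    residual_small = jnp.linalg.norm(r) <= atol + rtol * jnp.linalg.norm(b)
    diff_small = jnp.linalg.norm(diff) <= atol + rtol * jnp.linalg.norm(y)
    return residual_small | diff_small
```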
> where, if the solve reaches the tolerance on the final step, it returns `max_steps_reached` instead of success
Agreed, the termination criteria are a bit funky. Cf. also the discussion here:
https://github.com/patrick-kidger/lineax/pull/86#discussion_r2159488349
Recent changes to the behaviour of max_steps_reached: https://github.com/patrick-kidger/lineax/pull/129