
Transposed preconditioned GMRES

Open f0uriest opened this issue 6 months ago • 0 comments

Description

I'm using `jax.scipy.sparse.linalg.gmres` with a preconditioner to solve some linear systems, and was trying to figure out why the forward solve converges just fine but the gradient does not.

Looking through the code, it appears that `_solve` closes over the preconditioner `M`, so the same preconditioner is used for both the forward solve and the transpose solve in the backward pass: https://github.com/jax-ml/jax/blob/e04cc283d84a2df3ab0baa9c37f19f90600e11c1/jax/_src/scipy/sparse/linalg.py#L698-L700
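To make the mechanism concrete, here is a minimal sketch that mimics that structure with `jax.lax.custom_linear_solve`. It is not JAX's GMRES code: a hypothetical preconditioned Richardson iteration (`richardson_solve`) stands in for GMRES, and `A` is a made-up non-symmetric 2x2 matrix with `M` set to its exact inverse. The point is only that one solver closing over `M` is passed as both `solve` and `transpose_solve`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy non-symmetric system (made-up values). M is the exact inverse of A,
# so M @ A is the identity, but M @ A.T is not.
A = jnp.array([[1.0, 10.0],
               [0.0, 1.0]])
M = jnp.linalg.inv(A)

def matvec(x):
    return A @ x

def preconditioner(r):
    return M @ r

def richardson_solve(mv, b, num_iters=10):
    # Preconditioned Richardson iteration x <- x + M (b - mv(x)), starting at 0.
    # Like `_solve` in the issue, it closes over the single preconditioner M.
    def body(_, x):
        return x + preconditioner(b - mv(x))
    return lax.fori_loop(0, num_iters, body, jnp.zeros_like(b))

def solve(b):
    # The same solver is passed for the forward and the transpose solve,
    # so the transpose solve applies M (not M^T) to residuals of A^T y = c.
    return lax.custom_linear_solve(matvec, b, solve=richardson_solve,
                                   transpose_solve=richardson_solve)

b = jnp.array([1.0, 2.0])
x = solve(b)
grad_b = jax.grad(lambda b: jnp.sum(solve(b)))(b)
exact_grad = jnp.linalg.solve(A.T, jnp.ones(2))  # A^{-T} 1

print("forward residual:", jnp.linalg.norm(A @ x - b))
print("gradient error:  ", jnp.linalg.norm(grad_b - exact_grad))
```

With these toy values the forward residual should be essentially zero, while the gradient error blows up: each transpose iteration multiplies the error by $I - M A^T$, whose spectral radius here is roughly 99.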

Assuming the preconditioner satisfies $M A \sim I$, there's no reason to expect $M A^T \sim I$ unless $A$ is symmetric (which it usually isn't when GMRES is being used).
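A quick numerical check of this claim, assuming a made-up non-symmetric 3x3 `A` and taking `M` to be its exact inverse (the best possible preconditioner):

```python
import jax.numpy as jnp

# Hypothetical non-symmetric matrix; M is its exact inverse.
A = jnp.array([[4.0, 1.0, 0.0],
               [0.0, 3.0, 2.0],
               [1.0, 0.0, 5.0]])
M = jnp.linalg.inv(A)
I = jnp.eye(3)

# Small (round-off only): M is a good preconditioner for A.
err_forward = jnp.linalg.norm(M @ A - I)
# Not small: M @ A.T - I = A^{-1} (A^T - A), which is nonzero when A != A^T.
err_transpose = jnp.linalg.norm(M @ A.T - I)

print(err_forward, err_transpose)
```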

I think the correct thing to do would be to define a separate `_transpose_solve` that uses the transpose of the preconditioner: if $M \approx A^{-1}$, then $A M \sim I$ as well, and transposing gives $M^T A^T = (A M)^T \sim I$.
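A minimal sketch of the proposed fix, not JAX's actual code: a hypothetical preconditioned Richardson iteration (`make_richardson`) stands in for GMRES, `A` is a made-up non-symmetric 2x2 matrix, and `M` is its exact inverse. The forward solver applies `M`, while the transpose solver applies $M^T$. Since a preconditioner may be given as a function rather than a matrix, the sketch obtains $M^T$ with `jax.linear_transpose`, which assumes the preconditioner function is linear.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy non-symmetric system (made-up values); M @ A is the identity.
A = jnp.array([[1.0, 10.0],
               [0.0, 1.0]])
M = jnp.linalg.inv(A)

def matvec(x):
    return A @ x

def make_richardson(precond, num_iters=10):
    # Returns a solver running x <- x + precond(b - mv(x)) from x = 0.
    def solver(mv, b):
        def body(_, x):
            return x + precond(b - mv(x))
        return lax.fori_loop(0, num_iters, body, jnp.zeros_like(b))
    return solver

precond = lambda r: M @ r
# Transpose of the (linear) preconditioner; returns a 1-tuple of cotangents.
precond_T = jax.linear_transpose(precond, jnp.zeros(2))

solve = make_richardson(precond)                                # M for A x = b
transpose_solve = make_richardson(lambda r: precond_T(r)[0])    # M^T for A^T y = c

def linear_solve(b):
    return lax.custom_linear_solve(matvec, b, solve=solve,
                                   transpose_solve=transpose_solve)

b = jnp.array([1.0, 2.0])
grad_b = jax.grad(lambda b: jnp.sum(linear_solve(b)))(b)
exact_grad = jnp.linalg.solve(A.T, jnp.ones(2))  # A^{-T} 1
print(grad_b, exact_grad)
```

Here the transpose iteration's error is multiplied by $I - M^T A^T = (I - A M)^T$, which vanishes when $M = A^{-1}$, so `grad_b` should agree with the exact gradient $A^{-T}\mathbf{1}$.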

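(Also note that I think this is an issue for `jax.scipy.sparse.linalg.bicgstab` as well. `cg` is fine because it requires $A$ to be symmetric, and preconditioned CG also assumes a symmetric positive definite $M$, so $M^T = M$.)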

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.6.0
jaxlib: 0.6.0
numpy: 1.26.4
python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0]
device info: cpu-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='Nautilus', release='6.11.0-26-generic', version='#26~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Apr 17 19:20:47 UTC 2', machine='x86_64')

f0uriest, Jun 13 '25 02:06