Different Results on GPU and CPU for BiCGSTAB linear solve
Hello, I'm using lineax as the linear solver backend for some JAX-based finite element code I'm developing (to the lineax developers, thank you! lineax, equinox, etc. are amazing IMO), and I'm seeing some weird behavior when using a GPU vs. a CPU to solve a linear system with `lineax.BiCGStab`. While there may be a bug hiding in my code, and I might need to use a better preconditioner, I am still confused by the difference in behavior.
Context
I'm trying to replicate a structural dynamics demo with some minor modifications (a slightly different time integration scheme, and Newton's method + BiCGStab for the solves). I can reproduce the results of the demo on a CPU, but on a GPU only if I'm lucky - Newton iterations fail at random time steps. After enabling more output from lineax, it's clear that a form of iterative breakdown occurs during nearly every incremental linear solve within a Newton step, and eventually a NaN appears at a random timestep. Importantly, this lineax throw/exception is not reproduced when using a CPU to solve the same problem - no iterative breakdown occurs there. Additionally, the solution I obtain on the CPU looks correct when compared against the analytical eigenmode/frequency (as does the GPU solution, whenever I manage to skirt the numerical issues I'm seeing).
A Minimal Example
After working backwards from the dynamics-related updates (where I first saw the issue occur) to a previously working problem, I was able to reproduce a breakdown in the iterative BiCGStab solve by using a buggy material model:
```python
import jax.numpy as np   # assuming jax.numpy is aliased to np, as in the surrounding FEM code

lmbda, mu = 1.0, 1.0     # Lame parameters (placeholder values)

def stress(u_grad):
    I = np.eye(3)
    F = u_grad + I               # deformation gradient
    C = F.T @ F                  # right Cauchy-Green tensor
    E = (1 / 2) * (C - I)        # Green-Lagrange strain
    # Piola-Kirchhoff stress - BUGGY on GPU only, and only when dynamics are non-negligible
    # P = (lmbda * np.trace(E) * I + 2 * mu * E)
    # correct Piola-Kirchhoff stress - NOT 'buggy' on GPU or CPU
    P = F @ (lmbda * np.trace(E) * I + 2 * mu * E)
    return P
```
While this might imply that there is a bug in my code, the confusing part is that the lineax linear solves still converge on a CPU (the Newton solve does not converge, due to the poorly defined material model). In contrast, on a GPU the linear solve breaks down and occasionally produces a NaN. The attached logs show the difference in solver behavior. This is the same issue I'm seeing in my larger structural dynamics code (which uses a 'correct' material model).
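For reference, here is a minimal sketch (with a tiny stand-in system, not my FEM code) of one way to surface the solver outcome instead of letting lineax throw; it assumes lineax's `throw=False` option and the `result`/`stats` fields of the returned `Solution`:

```python
import jax
import jax.numpy as jnp
import lineax as lx

jax.config.update("jax_enable_x64", True)

# Tiny stand-in system (the real LHS/RHS come from the FEM assembly).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

sol = lx.linear_solve(
    lx.MatrixLinearOperator(A),
    b,
    solver=lx.BiCGStab(rtol=1e-10, atol=1e-10),
    throw=False,  # don't raise on breakdown; inspect the result instead
)
print(sol.result)  # lineax's RESULTS code: successful vs. breakdown/divergence
print(sol.stats)   # iteration statistics, useful for comparing CPU vs. GPU runs
```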
Questions
Even if there is a small bug in my code somewhere, why is there a difference between CPU and GPU behavior during the linear solve? I have done everything I can to force all computations to use float64 (see 'Debugging Attempts' below).
Other questions:
- Any tips for how to continue to debug this? I'm hitting my JAX debugging knowledge limits!
- Is this observed difference expected for ill-conditioned systems? Should I just be looking for a better preconditioner? (I'm using a Jacobi preconditioner at the moment; a rough sketch of that kind of setup follows this list.)
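For context, here is a sketch of the kind of Jacobi-preconditioned setup I mean (stand-in matrix, not my actual code; it assumes the preconditioner is passed via `options={"preconditioner": ...}` as described in the lineax docs for the iterative solvers):

```python
import jax
import jax.numpy as jnp
import lineax as lx

jax.config.update("jax_enable_x64", True)

# Stand-in for the assembled tangent matrix and residual.
key = jax.random.PRNGKey(0)
K = jax.random.normal(key, (50, 50)) + 50.0 * jnp.eye(50)
b = jnp.ones(50)

# Jacobi preconditioner: approximate K^{-1} by 1 / diag(K).
inv_diag = 1.0 / jnp.diag(K)
precond = lx.FunctionLinearOperator(lambda v: inv_diag * v, jax.eval_shape(lambda: b))

sol = lx.linear_solve(
    lx.MatrixLinearOperator(K),
    b,
    solver=lx.BiCGStab(rtol=1e-8, atol=1e-8),
    options={"preconditioner": precond},
)
print(sol.value)
```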
Debugging Attempts
While trying to diagnose the problem, I attempted or noticed the following (nothing has solved the problem):
- setting `jax_default_matmul_precision` to `'highest'` to force GPU-based linear algebra backends to use 64-bit precision (I am on an A100, FYI; I'm trying to use `'highest'` per the discussion in JAX #19444)
- looking through related issues: JAX #22557, tracked by JAX #18934
- GPU vs. CPU different behavior: JAX #22382. There are some in-place updates made via `.at[]` that I haven't looked into changing quite yet, but it seems like this has been fixed in XLA #19716.
Other debugging attempts/important mentions:
- 64-bit precision is enabled in JAX (see the configuration sketch after this list)
- the `jax_debug_nans` flag doesn't catch the issue
- un-jitting the linear solve and using a breakpoint to re-run the problematic/NaN-inducing iteration, with the same LHS, RHS, and initial guess, does NOT reproduce the NaN
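For completeness, a minimal sketch of the configuration described above (the flag names are the standard JAX config options):

```python
import jax

jax.config.update("jax_enable_x64", True)                     # 64-bit arrays everywhere
jax.config.update("jax_default_matmul_precision", "highest")  # avoid reduced-precision matmuls on the A100
jax.config.update("jax_debug_nans", True)                     # did not catch the issue here
```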
JAX version
note: I'm using lineax 0.0.8 and equinox 0.13.1
jax: 0.6.2
jaxlib: 0.6.2
numpy: 2.2.2
python: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]
device info: NVIDIA A100 80GB PCIe-2, 2 local devices"
process_count: 1
platform: uname_result(system='Linux', node='tralfamadore', release='6.8.0-85-generic', version='#85~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 16:18:59 UTC 2', machine='x86_64')
$ nvidia-smi
Wed Oct 15 14:14:58 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06 Driver Version: 580.65.06 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100 80GB PCIe Off | 00000000:17:00.0 Off | 0 |
| N/A 38C P0 72W / 300W | 445MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100 80GB PCIe Off | 00000000:65:00.0 Off | 0 |
| N/A 37C P0 72W / 300W | 441MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 2176 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 656991 C python 422MiB |
| 1 N/A N/A 2176 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 656991 C python 418MiB |
+-----------------------------------------------------------------------------------------+
Interesting! This is a fun one.
So it's great that you've already been able to isolate this to a particular set of inputs and to a particular iteration of the solve. (Though I'm not sure if you're referring to an iteration of the nonlinear solver or an iteration of the BiCGStab solver.)
You say that un-jit'ing the solve causes the issue to go away. That suggests something spooky: the XLA compiler is likely subtly misbehaving somehow.
I think probably the best way to diagnose this at this point would be to:
(a) construct an MWE which demonstrates a case in which CPU!=GPU, and which calls BiCGStab.compute directly. That is to say, cut out the lineax.linear_solve middleman. (The latter is simply a wrapper that applies autodiff rules to the solver, but that shouldn't matter at all here.) Make sure to still JIT the code, since that seems to matter. (A rough skeleton of this step is sketched after this list.)
(b) copy-paste the BiCGStab implementation so that your MWE calls your copy of the code;
(c) work through deleting pieces of the code, to make the MWE smaller, in each case checking to see if CPU!=GPU. Keep all of the code inside of a JIT'd region, since that seems to matter. It doesn't matter at all that this won't be a linear solver any more, we're just looking for a snippet of code on which CPU!=GPU.
(d) if we get lucky, then you'll be able to identify a small-ish chunk of code that demonstrates this issue, and if we're even luckier then this code will have no dependencies other than JAX.
(e) open an issue on the JAX GitHub with that snippet of code.
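For instance, a skeleton of steps (a)/(b) might look something like this. (The stand-in system, tolerances, and device handling below are placeholders: swap in your saved LHS/RHS. It also assumes a GPU is visible and drives the solver through its `init`/`compute` methods.)

```python
import jax
import jax.numpy as jnp
import lineax as lx

jax.config.update("jax_enable_x64", True)

# Placeholder system; replace with the saved LHS/RHS from the failing iteration.
key = jax.random.PRNGKey(0)
A_host = jax.random.normal(key, (200, 200)) + 200.0 * jnp.eye(200)
b_host = jnp.ones(200)

solver = lx.BiCGStab(rtol=1e-10, atol=1e-10)

@jax.jit  # keep the whole solve inside a JIT'd region, since that seems to matter
def solve(A, b):
    operator = lx.MatrixLinearOperator(A)
    state = solver.init(operator, options={})
    solution, result, stats = solver.compute(state, b, options={})
    return solution

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]  # assumes a GPU is available
x_cpu = solve(jax.device_put(A_host, cpu), jax.device_put(b_host, cpu))
x_gpu = solve(jax.device_put(A_host, gpu), jax.device_put(b_host, gpu))
print(jnp.max(jnp.abs(x_cpu - jax.device_put(x_gpu, cpu))))
```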
On your second question about whether this is expected: in some edge cases this behaviour is expected/reasonable. CPUs and GPUs will perform slightly different operations at the floating point level, in particular around nontrivial linear algebra operations. And in some cases, small fluctuations there can propagate to large changes in the system. Hard to say if this is considered acceptable or not without a concise MWE.
I hope that helps!
Thanks for the quick response, and I appreciate you outlining those steps - I'll give them a shot and report back with whatever I find!
After some digging, it seems like this is a JAX issue. Some updates:
I got to about step (d) and found the first instance where CPU!=GPU, at line 145 of lineax's bicgstab:
```python
x = preconditioner.mv(p_new)
v_new = operator.mv(x)
```
However, as explained in JAX #19444, the observed difference was likely due to non-associativity of certain floating point operations (the additions that occur in mat-vec multiplication, in our case, I think). The difference between `v_new` on the GPU vs. the CPU (on the order of $10^{-15}$) was roughly what I would expect for this.
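As a sanity check on that explanation, here's a toy illustration (completely unrelated to the FEM system) that reduction order alone perturbs float64 results at this scale:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Float64 addition is not associative:
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

# So reducing the same data in a different order generally differs by a few ulps,
# which is roughly the ~1e-15 scale seen in v_new above.
x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
print(jnp.sum(x) - jnp.sum(x[::-1]))  # typically a tiny nonzero value
```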
I spent some time investigating further, and I observed the following:
- lineax's and JAX's `BiCGStab` methods produce exactly the same results on GPU when the preconditioner is defined as a `MatrixLinearOperator` - so this is likely not a lineax issue!
- behavior differs slightly (a breakdown is still observed, but the solver consistently progresses a bit further) when the preconditioner is defined as a `FunctionLinearOperator` instead of a `MatrixLinearOperator`. This seems to confirm that the ordering of operations can slightly affect the results/convergence of BiCGStab. I didn't dive deep into this, just thought it was interesting, and it helped guide me towards a potential solution:
After digging through another GPU-based implementation of BiCGStab in AMGX, I noticed a minor difference in how a subtraction/scaling of vectors is implemented in the 3rd line of the algorithm. A small update to this line in `jax.scipy.sparse.linalg.bicgstab` seems to fix the iterative breakdown observed on the GPU; I plan to open a JAX issue with more detail.
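(For reference: in the standard van der Vorst presentation of BiCGStab, that third line is the search-direction update $p_i = r_{i-1} + \beta \, (p_{i-1} - \omega_{i-1} v_{i-1})$, i.e. exactly a subtraction and scaling of vectors.)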
I think this issue can be closed for now, as the agreement between the lineax and JAX BiCGStab methods suggests this is a JAX-related issue (or maybe still an edge case - I've only tested this on one problem!). @patrick-kidger those steps helped a lot, thanks again! If anything in the JAX repo is updated in relation to this issue, I plan to submit a lineax PR that reflects those changes.
Awesome! I'm glad we have an answer.
I'd be very happy to take a PR making this tweak if/when you feel it's appropriate. (I can see there are some ongoing discussions in the other thread.)