cache_miss caused by `lx.linear_solve`
Hi there,
I have discovered that TridiagonalLinearOperator triggers a cache miss under jit, which leads to a silent recompilation. This is similar to this jax issue; however, when I tested with the flags jax_log_compiles=True and jax_explain_cache_misses=True, nothing showed up (a truly pathological silent error), and I'm still not entirely sure why.
I originally noticed this in an interpax function which was slowing down my code by ~an order of magnitude (I have an issue on interpax going into more detail), but after delving into the interpax source code I found it was the lineax.TridiagonalLinearOperator that was causing the cache miss (here).
I've diagnosed this by code profiling with Perfetto, where you can actually see the cache_miss function being called each time.
Here is a minimal reproducible example purely with lineax. I am using lineax=0.0.8, jax=0.7.0, and equinox=0.13.0.
import lineax as lx
import equinox as eqx
import jax
from jax import numpy as jnp
# flags for debugging
jax.config.update("jax_log_compiles", True)
jax.config.update("jax_explain_cache_misses", True)
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(diag, lower_diag, upper_diag, b):
    A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)
    solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
    fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
    return fx
# setting up inputs
n = 5
diag = jnp.ones(n)
lower_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1)
b = jnp.linspace(0, 1, n)
# compiling
f(diag, lower_diag, upper_diag, b)
print("Compilation done.")
# running five times and tracing with perfetto
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        f(diag, lower_diag, upper_diag, b).block_until_ready()
I've attached the output of this script at the end. It contains lots of logs from the "jax_log_compiles" flag, but none of them occur under jit, and the "jax_explain_cache_misses" flag is completely silent.
This is a screenshot of the Perfetto trace. I've also attached the Perfetto trace file here.
Do you have any idea what could be causing this?
Thanks, Max
Script output:
WARNING:2025-08-19 18:15:49,152:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000162125 sec
WARNING:2025-08-19 18:15:49,162:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,200:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.037410975 sec
WARNING:2025-08-19 18:15:49,252:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.052269220 sec
WARNING:2025-08-19 18:15:49,255:jax._src.dispatch:198: Finished tracing + transforming broadcast_in_dim for pjit in 0.000164032 sec
WARNING:2025-08-19 18:15:49,255:jax._src.interpreters.pxla:1861: Compiling jit(broadcast_in_dim) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,259:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.003476143 sec
WARNING:2025-08-19 18:15:49,303:jax._src.dispatch:198: Finished XLA compilation of jit(broadcast_in_dim) in 0.043696880 sec
WARNING:2025-08-19 18:15:49,303:jax._src.dispatch:198: Finished tracing + transforming broadcast_in_dim for pjit in 0.000151873 sec
WARNING:2025-08-19 18:15:49,304:jax._src.interpreters.pxla:1861: Compiling jit(broadcast_in_dim) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,306:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001924992 sec
WARNING:2025-08-19 18:15:49,312:jax._src.dispatch:198: Finished XLA compilation of jit(broadcast_in_dim) in 0.006244898 sec
WARNING:2025-08-19 18:15:49,313:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000132084 sec
WARNING:2025-08-19 18:15:49,313:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000101089 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000108004 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000113964 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000083208 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000087023 sec
WARNING:2025-08-19 18:15:49,315:jax._src.dispatch:198: Finished tracing + transforming _linspace for pjit in 0.002079010 sec
WARNING:2025-08-19 18:15:49,315:jax._src.interpreters.pxla:1861: Compiling jit(_linspace) with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-08-19 18:15:49,319:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(_linspace) in 0.004170895 sec
WARNING:2025-08-19 18:15:49,341:jax._src.dispatch:198: Finished XLA compilation of jit(_linspace) in 0.021125078 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000157833 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000091791 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000079870 sec
WARNING:2025-08-19 18:15:49,343:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000072956 sec
WARNING:2025-08-19 18:15:49,344:jax._src.dispatch:198: Finished tracing + transforming _squeeze for pjit in 0.000031948 sec
WARNING:2025-08-19 18:15:49,344:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000031948 sec
WARNING:2025-08-19 18:15:49,346:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000039816 sec
WARNING:2025-08-19 18:15:49,346:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000141859 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming greater for pjit in 0.000206947 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000099897 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000038147 sec
WARNING:2025-08-19 18:15:49,348:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000409126 sec
WARNING:2025-08-19 18:15:49,348:jax._src.dispatch:198: Finished tracing + transforming less for pjit in 0.000158072 sec
WARNING:2025-08-19 18:15:49,349:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000149965 sec
WARNING:2025-08-19 18:15:49,350:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000894070 sec
WARNING:2025-08-19 18:15:49,350:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000102282 sec
WARNING:2025-08-19 18:15:49,351:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000064850 sec
WARNING:2025-08-19 18:15:49,353:jax._src.dispatch:198: Finished tracing + transforming isfinite for pjit in 0.000086069 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming invert for pjit in 0.000070810 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000190973 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000198126 sec
WARNING:2025-08-19 18:15:49,355:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000075817 sec
WARNING:2025-08-19 18:15:49,355:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(bool[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,357:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001612186 sec
WARNING:2025-08-19 18:15:49,362:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.005482912 sec
WARNING:2025-08-19 18:15:49,363:jax._src.dispatch:198: Finished tracing + transforming bitwise_and for pjit in 0.000178814 sec
WARNING:2025-08-19 18:15:49,363:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000038147 sec
WARNING:2025-08-19 18:15:49,364:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000247240 sec
WARNING:2025-08-19 18:15:49,365:jax._src.dispatch:198: Finished tracing + transforming equal for pjit in 0.000129938 sec
WARNING:2025-08-19 18:15:49,368:jax._src.dispatch:198: Finished tracing + transforming not_equal for pjit in 0.000437021 sec
WARNING:2025-08-19 18:15:49,407:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000070810 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming branched_error_if_impl for pjit in 0.038447857 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.062253237 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming linear_solve for pjit in 0.064262867 sec
WARNING:2025-08-19 18:15:49,409:jax._src.dispatch:198: Finished tracing + transforming f for pjit in 0.067122936 sec
WARNING:2025-08-19 18:15:49,409:jax._src.interpreters.pxla:1861: Compiling jit(f) with global shapes and types [ShapedArray(float32[5]), ShapedArray(float32[4]), ShapedArray(float32[4]), ShapedArray(float32[5])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-08-19 18:15:49,412:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000206947 sec
WARNING:2025-08-19 18:15:49,412:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000045061 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming greater for pjit in 0.000105143 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000081062 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000029325 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000193119 sec
WARNING:2025-08-19 18:15:49,414:jax._src.dispatch:198: Finished tracing + transforming less for pjit in 0.000079870 sec
WARNING:2025-08-19 18:15:49,414:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000105858 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000109911 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000125885 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000073195 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000076056 sec
WARNING:2025-08-19 18:15:49,417:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000058889 sec
WARNING:2025-08-19 18:15:49,418:jax._src.dispatch:198: Finished tracing + transforming isfinite for pjit in 0.000057936 sec
WARNING:2025-08-19 18:15:49,418:jax._src.dispatch:198: Finished tracing + transforming invert for pjit in 0.000138998 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000274181 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000200033 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000107050 sec
WARNING:2025-08-19 18:15:49,420:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(bool[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,421:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001550674 sec
WARNING:2025-08-19 18:15:49,425:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.003887892 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming bitwise_and for pjit in 0.000257015 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000034094 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000252962 sec
WARNING:2025-08-19 18:15:49,427:jax._src.dispatch:198: Finished tracing + transforming equal for pjit in 0.000099182 sec
WARNING:2025-08-19 18:15:49,427:jax._src.dispatch:198: Finished tracing + transforming not_equal for pjit in 0.000099182 sec
WARNING:2025-08-19 18:15:49,434:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000093937 sec
WARNING:2025-08-19 18:15:49,434:jax._src.dispatch:198: Finished tracing + transforming branched_error_if_impl for pjit in 0.006947041 sec
WARNING:2025-08-19 18:15:49,445:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000211954 sec
WARNING:2025-08-19 18:15:49,446:jax._src.dispatch:198: Finished tracing + transforming _reduce_max for pjit in 0.000173092 sec
WARNING:2025-08-19 18:15:49,449:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(f) in 0.040294886 sec
WARNING:2025-08-19 18:15:49,498:jax._src.dispatch:198: Finished XLA compilation of jit(f) in 0.047907829 sec
Compilation done.
2025-08-19 18:15:49.499140: E external/xla/xla/python/profiler/internal/python_hooks.cc:416] Can't import tensorflow.python.profiler.trace
2025-08-19 18:15:49.501289: E external/xla/xla/python/profiler/internal/python_hooks.cc:416] Can't import tensorflow.python.profiler.trace
Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
127.0.0.1 - - [19/Aug/2025 18:16:00] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [19/Aug/2025 18:16:00] "OPTIONS /status HTTP/1.1" 501 -
127.0.0.1 - - [19/Aug/2025 18:16:00] code 404, message File not found
127.0.0.1 - - [19/Aug/2025 18:16:00] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [19/Aug/2025 18:16:00] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [19/Aug/2025 18:16:00] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
127.0.0.1 - - [19/Aug/2025 18:16:00] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -
Thanks for the comment. I have not studied the error message in much detail, but the Can't import tensorflow.python.profiler.trace is suspect. Note that the JAX memory guidance was updated in June; they now recommend XProf. Does it run fine without the Perfetto trace?
Otherwise, this could be due to 0.7.0, which has a number of memory-related bugs. Are you running on CPU or GPU? Perhaps try updating to the JAX nightly release, or rolling back to the last stable 0.6.x, to see if this persists?
Alternatively, we have updated the tridiagonal solve on the lineax main branch, so if you check out that version there's a chance it might play nicer with the new version of jax.
Thanks for the quick response @jpbrodrick89!
Rerunning the same test with jax==0.6.2, jax==0.5.3, and jax==0.4.38 gives ~~identical traces with~~ a cache miss and recompile present. I also tried installing the lineax main branch from source, and unfortunately the issue is still there. I have now confirmed this using XProf as well.
These traces were run on my CPU locally, but I have also tested on two different GPUs and I get the same issue.
I believe the Can't import tensorflow.python.profiler.trace would just go away if I had tensorflow installed in my environment, which is relatively annoying to do in the current setup due to version requirements.
I can definitely tell (and feel) that the silent recompile is occurring even when not tracing. The issue in my code was that I was using the "cubic2" interpolation method from interpax, which uses the lineax tridiagonal solve under the hood. Simply switching to the "cubic" interpolation method (which does not use lineax) sped my jitted functions up by a factor of ~30. Additionally, when using the "cubic2" method, moving to an A100 GPU gave only a ~10% speedup of my jitted functions (as opposed to the expected ~10000% speedup), which would be consistent with a silent recompilation falling back to the CPU. Using the "cubic" method got around this and I saw the expected speedup.
In debugging the interpax function, the only way I was able to get the cache_miss to disappear from the trace was to remove the lineax call.
Hi, me again, perhaps with some bad news...
I did some digging into the lineax source code and have been able to narrow down the source of the cache miss a little further. Rather unfortunately, it is in lx.linear_solve, which, if I am right, has much broader consequences than just the tridiagonal solver.
MRE, which I've adapted from this lineax docs example:
import lineax as lx
import equinox as eqx
import jax
from jax import random as jr
print("JAX version:", jax.__version__)
print("Equinox version:", eqx.__version__)
print("Lineax version:", lx.__version__)
# setting up inputs
matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
operator = lx.MatrixLinearOperator(matrix)
# compiling
lx.linear_solve(operator, vector) # is already jitted in lineax function def'n
print("Compilation done.")
# running five times and tracing with XProf
with jax.profiler.trace("/tmp/jax-trace"):
    for i in range(5):
        lx.linear_solve(operator, vector)
> JAX version: 0.6.2
> Equinox version: 0.13.0
> Lineax version: 0.0.8
> Compilation done.
> 2025-08-20 16:49:29.620495: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> 2025-08-20 16:49:29.623940: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
And here is the trace showing a cache miss and recompile:
I have narrowed it down to this line in lineax/_solve.py, which I believe is where the computation actually happens, by binding the linear solve primitive linear_solve_p. It is at this point that I fear I am asymptoting to the limit of where my jax/equinox knowledge can take me.
To justify why I believe this primitive is the issue, I copied lx.linear_solve verbatim and added a return statement just before the primitive is bound (see below). When I trace this version of the function, there is no cache miss! And no recompile!
import jax
from jax import random as jr, lax, tree_util as jtu
import equinox as eqx
import lineax as lx
from lineax._custom_types import sentinel
from lineax._misc import inexact_asarray, strip_weak_dtype
from lineax._operator import (
    AbstractLinearOperator,
    IdentityLinearOperator,
)
from lineax._solution import RESULTS, Solution
from jaxtyping import ArrayLike, PyTree
from lineax._solve import AutoLinearSolver, AbstractLinearSolver, linear_solve_p
from typing import Any
import equinox.internal as eqxi
@eqx.filter_jit
def linear_solve(
    operator: AbstractLinearOperator,
    vector: PyTree[ArrayLike],
    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),
    *,
    options: dict[str, Any] | None = None,
    state: PyTree[Any] = sentinel,
    throw: bool = True,
) -> Solution:
    """
    ...
    """
    if eqx.is_array(operator):
        raise ValueError(
            "`lineax.linear_solve(operator=...)` should be an "
            "`AbstractLinearOperator`, not a raw JAX array. If you are trying to pass "
            "a matrix then this should be passed as "
            "`lineax.MatrixLinearOperator(matrix)`."
        )
    if options is None:
        options = {}
    vector = jtu.tree_map(inexact_asarray, vector)
    vector_struct = strip_weak_dtype(jax.eval_shape(lambda: vector))
    operator_out_structure = strip_weak_dtype(operator.out_structure())
    # `is` to handle tracers
    if eqx.tree_equal(vector_struct, operator_out_structure) is not True:
        raise ValueError(
            "Vector and operator structures do not match. Got a vector with structure "
            f"{vector_struct} and an operator with out-structure "
            f"{operator_out_structure}"
        )
    if isinstance(operator, IdentityLinearOperator):
        return Solution(
            value=vector,
            result=RESULTS.successful,
            state=state,
            stats={},
        )
    if state == sentinel:
        state = solver.init(operator, options)
        dynamic_state, static_state = eqx.partition(state, eqx.is_array)
        dynamic_state = lax.stop_gradient(dynamic_state)
        state = eqx.combine(dynamic_state, static_state)
    state = eqxi.nondifferentiable(state, name="`lineax.linear_solve(..., state=...)`")
    options = eqxi.nondifferentiable(
        options, name="`lineax.linear_solve(..., options=...)`"
    )
    solver = eqxi.nondifferentiable(
        solver, name="`lineax.linear_solve(..., solver=...)`"
    )
    return solver
    # solution, result, stats = eqxi.filter_primitive_bind(
    #     linear_solve_p, operator, state, vector, options, solver, throw
    # )
    # # TODO: prevent forward-mode autodiff through stats
    # stats = eqxi.nondifferentiable_backward(stats)
    # return Solution(value=solution, result=result, state=state, stats=stats)
matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
operator = lx.MatrixLinearOperator(matrix)
# compile the function
linear_solve(operator, vector)
# running the function multiple times
with jax.profiler.trace("/tmp/jax-trace"):
    for i in range(5):
        linear_solve(operator, vector)
Can you check if this is fixed on https://github.com/johannahaffner/equinox/tree/main @maxecharles?
I can't reproduce the XProf profiles right now (it somehow only shows me empty pages on localhost), and I've given up fiddling around with this, ~~but I did run some tests and it seems that recompilation does get triggered without the fix, and does not occur with the fix.~~
Thanks for looking into this @johannahaffner.
Unfortunately I did not see a fix for the cache miss issue. I was able to reproduce the other error and confirm that your fork fixes it, but the cache miss still appeared in the same script.
However, you said that eqx.debug.assert_max_traces was reporting a recompile (jax==0.7.0)! I have not been able to reproduce this, with or without the fix in your fork. For me, it's just sailing through and not triggering anything, while still reporting a cache miss and compile in the trace.
import lineax as lx
import equinox as eqx
import jax
from jax import random as jr
print("JAX version:", jax.__version__)
print("Equinox version:", eqx.__version__)
print("Lineax version:", lx.__version__)
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(operator, vector):
    return lx.linear_solve(operator, vector)
# setting up inputs
matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
operator = lx.MatrixLinearOperator(matrix)
# compiling
f(operator, vector)
print("Compilation done.")
# running under jit
for i in range(5):
    f(operator, vector)
print("Done.")
> JAX version: 0.7.0
> Equinox version: 0.13.0
> Lineax version: 0.0.8
> Compilation done.
> Done.
How were you able to trigger assert_max_traces?
Re: XProf, just checking your issue isn't that you didn't select "Trace Viewer" from the Tools dropdown menu – it took me longer than I care to admit to find that. I definitely have found Perfetto to be easier to work with.
> How were you able to trigger assert_max_traces?
I think I missed something there! With your example I also cannot reproduce it 🙈 perhaps I overlooked an inconsistent argument. I'm a little stumped; I ran this to check that adding the "__weakref__" does not break anything here.
> Re: XProf, just checking your issue isn't that you didn't select "Trace Viewer" from the Tools dropdown menu – it took me longer than I care to admit to find that. I definitely have found Perfetto to be easier to work with.
I don't see a Tools dropdown menu, and the dropdown menu I can see does not have this option.
Hmm, I'm not sure I completely understand this one. Running your initial script, f appears to be compiled only once. To be precise, the following program prints 'Compiling' only once:
import lineax as lx
import equinox as eqx
import jax
from jax import numpy as jnp
@eqx.filter_jit
def f(diag, lower_diag, upper_diag, b):
print("Compiling")
A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)
solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
return fx
# setting up inputs
n = 5
diag = jnp.ones(n)
lower_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1)
b = jnp.linspace(0, 1, n)
# compiling
f(diag, lower_diag, upper_diag, b)
for i in range(5):
    f(diag, lower_diag, upper_diag, b).block_until_ready()
IIUC, is your point that the JIT dispatch is hitting a slow path rather than a fastpath, i.e. adding a few milliseconds per call?
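(If it's useful, one rough way to check that is a sketch like the following, using plain jax.jit and JAX's ahead-of-time lower/compile API, nothing Lineax-specific about the approach: if the cost is in jit dispatch, calling the pre-compiled executable directly might make most of the per-call overhead disappear.)
import jax
import lineax as lx
from jax import random as jr

matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))

def g(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    return lx.linear_solve(operator, vector).value

jitted = jax.jit(g)
# ahead-of-time lower + compile; calling the result skips the usual jit cache lookup
compiled = jitted.lower(matrix, vector).compile()

jitted(matrix, vector).block_until_ready()    # goes through the normal jit dispatch path
compiled(matrix, vector).block_until_ready()  # calls the compiled executable directly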
Hi @patrick-kidger.
Yes... sort of. At least in the case of my code, I think it's quite a bit more than just a few milliseconds... sometimes. I've just discovered a change in behaviour based on jax version, but maybe I should take a step back and explain how I got here first. Sorry in advance for the massive incoming information dump, but I think it might be important.
In my code, I am basically just doing gradient descent on an equinox model, which uses interpax interpolation multiple times in the model evaluation. I noticed that when I ran the code on GPU, it was barely any faster than on CPU, which prompted me to start doing the code profiling I have been showing.
Below is a trace similar to what I initially saw, running on the GPU (2080Ti). This is just a loop running the model evaluation function under jit five times in a row. There is obviously something wrong here — you can see that about 80% of the function evaluation time is spent on this compile call... on the CPU! Then it moves to the GPU and spends a small amount of time computing and then the vicious cycle repeats. Yuck. This was obviously the chokepoint that was slowing down my code, and I assumed this compile was being triggered by the cache_miss seen three lines above.
> JAX version 0.6.0
>
> Compiling model...
> Model compiled in 27.30 seconds
>
> Timing model...
> Model 0 took 0.92 seconds
> Model 1 took 0.89 seconds
> Model 2 took 0.90 seconds
> Model 3 took 0.89 seconds
> Model 4 took 0.90 seconds
So, what was causing this cache_miss? This is the start of the trail of breadcrumbs that led me to the linear_solve_p primitive.
Through code profiling (adding return statements progressively earlier in the function until the cache_miss goes away) I figured out that the cache_miss in my function was caused by the interpax interpolation. When I changed the interpolation method from "cubic2" to "cubic" (which does not use lineax), below is the trace I saw. There is no cache_miss and no compile calls. It only runs on the CPU for pre-processing and then just goes brrrrr on the GPU pretty much the whole time (I assume normal jit behaviour). By avoiding this cache_miss I am seeing an order of magnitude speedup under jit, and interestingly the first compile time was much shorter too (not sure why that would be, tbh).
> JAX version 0.6.0
> Compiling model...
> Model compiled in 6.08 seconds
>
> Timing model...
> Model 0 took 0.13 seconds
> Model 1 took 0.10 seconds
> Model 2 took 0.10 seconds
> Model 3 took 0.10 seconds
> Model 4 took 0.10 seconds
I really did need that specific "cubic2" interpolation method, so I continued following the trail of breadcrumbs to find the source of the cache_miss. This led me here. My hope is that there is one line of code that could be altered to avoid the cache_miss and make everyone happy, which seems to be the case in this similar jax issue.
Here is a pure lineax version, similar to my previous posts but this time run on the GPU. Again, more than 90% of the time is spent silently recompiling on the CPU!
import lineax as lx
import equinox as eqx
import jax
from jax import random as jr, numpy as jnp
from time import time
print("JAX version:", jax.__version__)
print("Equinox version:", eqx.__version__)
print("Lineax version:", lx.__version__)
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(operator, vector):
print("COMPILING")
return lx.linear_solve(operator, vector).value
# setting up inputs
matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
operator = lx.MatrixLinearOperator(matrix)
# compiling
f(operator, vector).block_until_ready()
# running five times and tracing with perfetto
with jax.profiler.trace("/tmp/profile-data", create_perfetto_link=True):
    for i in range(5):
        f(operator, vector).block_until_ready()
> JAX version: 0.6.1
> Equinox version: 0.13.0
> Lineax version: 0.0.8
> COMPILING
> 2025-08-22 17:53:37.728192: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> 2025-08-22 17:53:37.751022: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
> 127.0.0.1 - - [22/Aug/2025 17:53:43] code 404, message File not found
> 127.0.0.1 - - [22/Aug/2025 17:53:43] "POST /status HTTP/1.1" 404 -
> 127.0.0.1 - - [22/Aug/2025 17:53:43] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [22/Aug/2025 17:53:43] "OPTIONS /status HTTP/1.1" 501 -
> 127.0.0.1 - - [22/Aug/2025 17:53:43] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -
OK, re the different behaviour based on jax version that I alluded to.
I noticed that if I run the identical script from my last comment, but with jax==0.6.2, the slowdown is significantly smaller:
- jax<=0.6.1 gives ~3ms per function call under jit
- jax>=0.6.2 gives ~500μs per function call under jit
Trace given below for jax==0.6.2.
There is still a glaringly obvious cache_miss being hit, but the compile function is harder to find. It is there, but it only runs for ~1μs (too short to see on the trace without zooming in). I found nothing obviously relating to this in the jax changelog for 0.6.2.
This could almost pass as a half-fix... it's quite a big speedup. But the cache_miss still being silently triggered definitely makes me uneasy, and the GPU is still being severely under-utilised.
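(Those per-call numbers come from the profiler traces; for a quick wall-clock cross-check, something along these lines works as a rough sketch:)
import time
import jax
import equinox as eqx
import lineax as lx
from jax import random as jr

matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
operator = lx.MatrixLinearOperator(matrix)

@eqx.filter_jit
def f(operator, vector):
    return lx.linear_solve(operator, vector).value

f(operator, vector).block_until_ready()  # compile once

# with a warm cache this should be dominated by dispatch overhead
times = []
for _ in range(50):
    t0 = time.perf_counter()
    f(operator, vector).block_until_ready()
    times.append(time.perf_counter() - t0)
times.sort()
print(f"median per-call time: {times[len(times) // 2] * 1e3:.3f} ms")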
Okay, got it!
I think the main few questions I'd be curious about:
- do you see this when using exclusively jax.Arrays as inputs across the JIT boundary? (E.g. no Lineax linear operators)
- do you see this when using jax.jit instead of eqx.filter_jit?
- do you see this when setting os.environ["EQX_ON_ERROR"] = "nan"? (Which changes the behaviour of eqx.error_if, which uses jax.pure_callback, and pure_callback is sometimes guilty of inducing weird behaviour.)
For context, the main possibilities that I am trying to disambiguate between here are:
- whether we are hitting a slow path on the JIT-dispatch boundary;
- whether the slowdown is due to what is being passed across that boundary or due to the computation itself;
- if it is the computation itself, whether it is due to an eqx.error_if/jax.pure_callback, which are frequently suspicious in performance-related concerns;
- if it is a callback, whether the overhead comes from the presence of the callback or from actually running the callback.
Got the most unhelpful answer:
- yes,
- yes,
- and yes
import lineax as lx
import jax
from jax import random as jr
import os
os.environ["EQX_ON_ERROR"] = "nan"
print("JAX version:", jax.__version__)
print("Lineax version:", lx.__version__)
@jax.jit
def f(matrix, vector):
print("COMPILING")
operator = lx.MatrixLinearOperator(matrix)
return lx.linear_solve(operator, vector).value
# setting up inputs
matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vector = jr.normal(jr.PRNGKey(1), (3,))
print(type(matrix), type(vector))
# compiling
f(matrix, vector).block_until_ready()
# running five times and tracing with perfetto
with jax.profiler.trace("/tmp/profile-data", create_perfetto_link=True):
    for i in range(5):
        f(matrix, vector).block_until_ready()
> JAX version: 0.6.0
> Lineax version: 0.0.8
> <class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>
> COMPILING
> 2025-08-24 22:29:26.113599: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> 2025-08-24 22:29:26.133550: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
> 127.0.0.1 - - [24/Aug/2025 22:29:30] code 404, message File not found
> 127.0.0.1 - - [24/Aug/2025 22:29:30] "POST /status HTTP/1.1" 404 -
> 127.0.0.1 - - [24/Aug/2025 22:29:30] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [24/Aug/2025 22:29:30] "OPTIONS /status HTTP/1.1" 501 -
> 127.0.0.1 - - [24/Aug/2025 22:29:30] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [24/Aug/2025 22:29:30] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
> 127.0.0.1 - - [24/Aug/2025 22:29:30] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -
Edit: Confirmed consistent result for jax==0.6.2 as well.
Okay! That actually really helps narrow things down: there's just not really much left after this!
The basic plan now is to substitute-and-simplify. In this case, copy-paste the definition of linear_solve into this script (so we gain a dependency on Lineax internals but lose the direct dependency on lx.linear_solve), and then see how much of the code can be deleted. For whatever remains, copy-paste in its definition and repeat. At some point we'll find an MWE that doesn't use Lineax at all.
In particular, this approach is significantly aided by the fact that there is only one transform (jax.jit), so none of the complexity we have around custom primitives / autodiff / etc. is needed. The eqxi.filter_primitive_bind(linear_solve_p, ... ) can just be substituted with _linear_solve_impl(..., check_closure=False).
If you're up for trying that out, then I'd be very happy to keep supporting this. If you don't have the time or inclination then no worries, and I'll try to pick this up at some point when I find some time :)
Thanks @patrick-kidger. This was very helpful.
You were right to suspect jax.pure_callback. I believe that's our perpetrator.
Here is a pure jax==0.6.0 MRE:
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
    # call a numpy (not jax.numpy) operation:
    return np.sin(x).astype(x.dtype)
@jax.jit
def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(f_host, result_shape, x, vmap_method="sequential")
x = jnp.arange(5.0)
# compiling
f(x).block_until_ready()
with jax.profiler.trace("/tmp/profile-data", create_perfetto_link=True):
    for i in range(5):
        f(x).block_until_ready()
Changing the EQX_ON_ERROR didn't help – I tried all options and none of them prevented a cache_miss.
The bandaid solution that I've tested with my own code is just to pass lx.linear_solve(..., throw=False). I tried to debug the jax.pure_callback for a bit, but I'm starting to bite off more than I can chew...
Thanks again for the help, I will open an issue on jax tomorrow.
Aha, awesome! Super glad to have identified a JAX-only MWE. That's great work!
> Changing the EQX_ON_ERROR didn't help – I tried all options and none of them prevented a cache_miss.
This makes sense, since this is specifically a thing that affects eqx.error_if, which is not part of this MWE.
> The bandaid solution that I've tested with my own code is just to pass lx.linear_solve(..., throw=False). I tried to debug the jax.pure_callback for a bit, but I'm starting to bite off more than I can chew...
Yup, this sounds like a good workaround in the mean time.
We have the occasional pure_callback scattered around in a few other places, which might cause issues, but if so, now that we know the root cause it should be easy enough to identify / work around.
Adding that error_if is also expensive, so if you're interested in minimising runtimes it is a good idea to disable it and check the result afterwards.
I frequently have a use-case where it is enough to just filter for successful solutions with a boolean mask obtained using
success = solution.result == lx.RESULTS.successful
You can also check for the fraction of successful solutions with jnp.sum(success) / success.size.
If you expect each linear solve to be successful your routine may differ, of course. So throw=False is a workaround that may have some benefits too, especially if you're vmapping the solve.
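For concreteness, here is a minimal sketch of that pattern (batch shape and names purely illustrative), combining throw=False with a success mask under vmap:
import jax
import jax.numpy as jnp
import lineax as lx
from jax import random as jr

matrix = jr.normal(jr.PRNGKey(0), (3, 3))
vectors = jr.normal(jr.PRNGKey(1), (10, 3))  # a batch of right-hand sides
operator = lx.MatrixLinearOperator(matrix)

@jax.jit
def solve_batch(vectors):
    sols = jax.vmap(lambda v: lx.linear_solve(operator, v, throw=False))(vectors)
    success = sols.result == lx.RESULTS.successful
    # mask out any failed solves instead of raising a runtime error
    values = jnp.where(success[:, None], sols.value, jnp.nan)
    return values, jnp.sum(success) / success.size

values, frac_successful = solve_batch(vectors)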
Alas... I am not quite out of the woods yet. So close.
The cache miss still appears in lx.linear_solve when taking gradients through it, even when passing throw=False.
e.g.
@jax.jit
@jax.grad
def f(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    return lx.linear_solve(operator, vector, throw=False).value.sum()
I can't be 100% sure whether this is a result of the same jax bug. I'd be happy to sift through the function again to find the root cause, but I'm not sure how eqxi.filter_primitive_bind interacts with jax.grad, so I'm not sure how to progress from there. The cache miss is definitely coming from within this line:
def linear_solve(...):
    ...
    solution, result, stats = eqxi.filter_primitive_bind(
        linear_solve_p, operator, state, vector, options, solver, throw
    )
    ...
Ah, this is an easy one; we unconditionally use throw=True in the backward pass. (We have nowhere to pipe the error to.)
You could adjust your copy of Lineax to work around this whilst we wait for a better fix.
An alternative workaround might be to adjust the nan version of error_if not to use callbacks (somehow) so that you can set EQX_ON_ERROR=nan to disable them all regardless of location.
Awesome. Worked! Thanks @patrick-kidger
That's a nicer looking trace. GPU go brrrr
> Ah, this is an easy one; we unconditionally use throw=True in the backward pass. (We have nowhere to pipe the error to.)
On this note, I think it would be useful to have the option to also set throw=False on the backward pass. When using an iterative solver, failure to converge on the backward pass basically just means that you're getting an approximate gradient. There are plenty of optimizers that can handle approximate gradients, but very few that can handle NaNs or errors.
> That's a nicer looking trace. GPU go brrrr
Awesome!
> When using an iterative solver, failure to converge on the backward pass basically just means that you're getting an approximate gradient.
Where here 'approximate' can be a very loose word indeed ;)
> On this note, I think it would be useful to have the option to also set throw=False on the backward pass.
I think I'd be happy to add this as an undocumented option for those power users who care enough to go spelunking through the source code / github issues to find it! (I'd like to avoid adding public 'make the error go away' buttons as traps for the unwary user.)