Segfault on Linux with JAX 0.5.2 when filter-jitting jvp-of-jvp-of-traced-multiplication
Hi Patrick,
I managed to whittle the issue uncovered in https://github.com/patrick-kidger/lineax/pull/140 down to the following MWE. It seems we don't need Lineax, but we do need Equinox to produce this error. It only shows up after a certain number of operations have been performed; I raised the iteration count to as much as 5000 but did not get the same error without `equinox.filter_jit`. I could remove Equinox from the JVPs and still keep the error.
It persists when the linear solve is trivial, but does not seem to occur with something like `jnp.matmul`.
```python
import functools as ft

import equinox as eqx
import jax
import jax.numpy as jnp

USE_EQUINOX = True
num_iterations = 1000


def jvp_jvp_impl():
    mat = jnp.eye(3)
    vec = jnp.ones(3)
    jvp_solve = lambda v: jax.jvp(jnp.linalg.solve, (mat, v), (mat, v))
    jvp_jvp_solve = ft.partial(jax.jvp, jvp_solve)
    if USE_EQUINOX:
        (out, t_out), (minus_out, tt_out) = eqx.filter_jit(jvp_jvp_solve)([vec], [vec])
    else:
        (out, t_out), (minus_out, tt_out) = jax.jit(jvp_jvp_solve)([vec], [vec])


if __name__ == "__main__":
    for _ in range(num_iterations):
        jvp_jvp_impl()
        print(".", end="", flush=True)
```
Ah awesome! We're making real progress here.
Do we specifically need the jvp-of-jvp? That's a weird higher-order construct, so I'm willing to believe something is up there! (Given that jvp-of-linsolve involves further linsolves, perhaps we could replicate it by having more iterations and no jvp?)
Does this still occur if you write `foo = eqx.filter_jit(jvp_jvp_solve)` outside the loop, and then call `foo([vec], [vec])` inside the loop? That is, is the error from decorating multiple times or from calling multiple times?
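For concreteness, a sketch of that variant, reusing the construction from the MWE above (`foo` is decorated once, outside the loop):

```python
import functools as ft

import equinox as eqx
import jax
import jax.numpy as jnp

mat = jnp.eye(3)
vec = jnp.ones(3)
jvp_solve = lambda v: jax.jvp(jnp.linalg.solve, (mat, v), (mat, v))
foo = eqx.filter_jit(ft.partial(jax.jvp, jvp_solve))  # decorate once

for _ in range(1000):
    foo([vec], [vec])  # call many times
```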
What about if you insert `equinox.clear_caches()` calls in there?
Either way, it looks like the next step is to tear apart the implementation of filter_jit, to reduce this down to an example using just jax.jit.
Quick question: why can we donate the first argument here: https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_jit.py#L63
That is one of the differences that jumps out at me, but I cannot directly replicate it using `jax.jit`: I get the (expected) invalid argument error. I tried running it with just `jax.jit` and default arguments for more than 100k iterations, and I did not get a segfault. But I do get one after about 5k iterations when using `equinox.filter_jit`.
Agreed on all the other points, will try them later today!
Under the hood, `eqx.filter_jit` splits its arguments into three pieces, which are then passed to a `jax.jit`-wrapped function. Those pieces are (a) donated arrays; (b) non-donated arrays; (c) static non-arrays. It may be the case that piece (a) is empty.
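For reference, a minimal sketch of that splitting (a hypothetical `toy_filter_jit`, not Equinox's actual implementation: positional arguments only, no donation, and it assumes the static leaves are hashable):

```python
import functools as ft

import equinox as eqx
import jax


def toy_filter_jit(fun):
    @ft.partial(jax.jit, static_argnums=1)
    def run(dynamic, static):
        # Recombine the traced arrays with the static non-arrays and call `fun`.
        args = eqx.combine(dynamic, static)
        return fun(*args)

    def wrapper(*args):
        # Arrays become traced jit inputs; everything else is passed statically.
        dynamic, static = eqx.partition(args, eqx.is_array)
        return run(dynamic, static)

    return wrapper
```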
I tried a few more things.
- We only need a multiplication of two traced arrays: we can trigger it with `lambda a, b: a * b`, but not with `lambda a, b: a + b` (see the sketch after this list).
- Wrapping the arrays in `stop_gradient` avoids the segmentation fault.
- This is a compile-time error that only occurs if we compile repeatedly and do not clear the caches (using either `eqx.clear_caches()` or `jax.clear_caches()`).
- A double JVP appears to be necessary; I could not get it to trigger with a single JVP and a generous number of iterations.
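A sketch of the reduced repro from the first bullet, assuming the same loop structure as the original MWE (how many iterations are needed before the crash varies):

```python
import functools as ft

import equinox as eqx
import jax
import jax.numpy as jnp


def mul_impl():
    a = jnp.ones(3)
    # jvp of a plain traced multiplication, then a second jvp on top of it.
    jvp_mul = lambda x, y: jax.jvp(lambda u, v: u * v, (x, y), (x, y))
    jvp_jvp_mul = ft.partial(jax.jvp, jvp_mul)
    eqx.filter_jit(jvp_jvp_mul)((a, a), (a, a))


if __name__ == "__main__":
    while True:  # reportedly takes a few thousand iterations to crash
        mul_impl()
        print(".", end="", flush=True)
```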
Next steps: I'm going to take a look at what gets cached during filtered compilation and anything hashable created in the process, and take a look at what changed in JAX and jaxlib/XLA that could possibly affect this, since we did not get this before.
> (using either `eqx.clear_caches()` or `jax.clear_caches()`)
The optionality here is a little surprising to me -- these clear totally unrelated caches, I think.
This is the memory growth with `eqx.filter_jit` and without cache clearing:

[memory-usage plot]

It keeps growing until it can't allocate memory anymore. With `jax.jit`, or with either of the cache-clearing operations, the memory allocation remains stable, with a peak of 2.25 MB.
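For what it's worth, one rough way to watch this from the script itself (a sketch only: `tracemalloc` sees Python-level allocations, so the numbers won't match a real profiler; it assumes the `jvp_jvp_impl` and `num_iterations` from the MWE above):

```python
import tracemalloc

tracemalloc.start()
for i in range(num_iterations):
    jvp_jvp_impl()
    if i % 500 == 0:
        current, peak = tracemalloc.get_traced_memory()
        print(f"iteration {i}: current={current / 1e6:.2f} MB, peak={peak / 1e6:.2f} MB")
```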
So I can see that you're creating a new jvp_solve function on every iteration. I think we keep strong references to the functions in the cache.
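As a toy illustration of that failure mode (this is not JAX's actual cache machinery, just the general pattern): a cache that is keyed on function objects and holds strong references can never be hit, or emptied, if a fresh function is created on every iteration.

```python
cache = {}

for _ in range(3):
    f = lambda x: x * x  # a brand-new function object every time through the loop
    cache[f] = "compiled artifact"  # strong reference keeps f (and its closure) alive

print(len(cache))  # 3: no entry is ever reused or freed
```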
FWIW https://github.com/jax-ml/jax/issues/16226 is potentially relevant (although it deals primarily with arguments rather than the wrapped function).
So I think that might just be a red herring based on the way this example is set up?
Ah, that could be. I was just checking how far back this behaviour persists, and it goes quite far back (in terms of Equinox versions), but it did not throw segfaults before. However, this has to do with repetition, and I think it has something to do with memory too. For example, yesterday morning I needed to run this script for more than 100k iterations to get the first segfault; thereafter I got segfaults after about 5k iterations, until that dropped down to 4k.
Repeatedly creating functions that do more or less the same thing and compiling them individually is not the usual use case, but it is what we do in the tests, and doing things this way seems to be required to trigger this.
I was able to get an error message that is a bit more informative:
This file had changes to its imports two months ago, about a week after the lineax CI ran (without errors) for the changes I introduced back then. This is the commit. The changes, as described, seem relevant.
I'm going to check this out tomorrow. With respect to more and more memory being allocated / strong references in the cache: do you think that is expected behaviour and not something we could/should prevent in equinox?
Aha awesome! That's excellent progress.
I think I'd start by settling for a repro that doesn't use Equinox. Then we'll have something that we can send to the JAX/XLA team, and also have something concrete we can test workarounds on.
Trying to! A really weird thing: I can make `filter_jit` an alias of `jax.jit` and comment out EVERYTHING ELSE in `equinox/_jit.py`, and I still get the same memory profile as above, with progressively more memory allocated over the iterations, which does not happen with just `jax.jit`. I also still get the segfault when doing this.
I'm trying to figure out what we're doing in equinox for this specific use case that jax is not doing, so I can replicate that and get the same error without equinox in the picture. But now I'm wondering if it is something that jax is now doing and that we should be doing: maybe we're missing something, instead of having added something.
Alright, I was masking one effect with another. But now I got some clarity!
There are two ways to make this fail, one of them without Equinox. Without Equinox, it is enough to append the function we're creating (`jvp_jvp_solve`) to a list at each iteration and keep it around (see the sketch below). With Equinox, we fail as soon as we use the wrapped function created in `_filter_jit_cache`. It does not matter whether we store it anywhere, as long as we use it.
They fail at the same time, around 5100 iterations, and the memory allocated until then is the same as well.
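A sketch of the Equinox-free variant (same solve construction as the original MWE; the names here are made up):

```python
import functools as ft

import jax
import jax.numpy as jnp

kept_functions = []  # keeping these around is what triggers the failure


def jvp_jvp_impl_no_eqx():
    mat = jnp.eye(3)
    vec = jnp.ones(3)
    jvp_solve = lambda v: jax.jvp(jnp.linalg.solve, (mat, v), (mat, v))
    jvp_jvp_solve = ft.partial(jax.jvp, jvp_solve)
    kept_functions.append(jvp_jvp_solve)  # hold a strong reference each iteration
    jax.jit(jvp_jvp_solve)([vec], [vec])
```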
Aha, interesting! So either way it looks like it's probably to do with keeping too much around in memory.
So I think that gives us two clear action items: (a) report the MWE to the JAX/XLA team; (b) work around this in Lineax by adding the following to the failing test file:
```python
@pytest.fixture(autouse=True)
def _clear_cache():
    eqx.clear_caches()
```
Agreed.