Equinox has spikes in training step durations due to garbage collector
Hi all,
this is related to https://github.com/jax-ml/jax/issues/30185; however, the issue here occurs when running already jitted functions in combination with Equinox, not during compilation.
When training neural networks in Equinox, I noticed that some steps take unusually long in a periodic manner. This gets worse the larger the neural network and the more GPUs are in use, making scaling up impossible. As an example: when I trained a 2-million-parameter transformer on 8 GPUs, every 100th step took around a minute while all others took under a second. The problem occurs across different machines and Python, JAX, and Equinox versions. Disabling the garbage collector solves it, so apparently the garbage collector has trouble resolving cyclic references created by the training step.
Here is an example using a small Vision Transformer:
The problem is difficult to replicate with a small working example. I took the ViT training example from Equinox and removed dataloading to make sure the behavior is not caused by it. You can find the code here: https://github.com/AlexGraefe/jax_training_spikes. There, I also implemented ViT training in NNX to make sure it is not JAX itself.
Does anyone see similar behavior and have a workaround for this? Disabling the garbage collector might not be the wisest solution...
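To be explicit about what I mean by "disabling the garbage collector", here is a minimal sketch of the workaround. The periodic manual gc.collect() and the names num_steps/train_step/model/opt_state/batch are placeholders of mine, not code from the repo:

import gc

# Switch off automatic garbage collection during training and collect manually
# at a point where a pause is acceptable. Collecting every 1000 steps is an
# arbitrary choice for this sketch.
gc.disable()

for step in range(num_steps):
    model, opt_state, loss = train_step(model, opt_state, batch)
    if step % 1000 == 0:
        gc.collect()  # pay the collection cost at a chosen point instead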
Can you try this with different versions of JAX? This seems like an underlying JAX issue rather than an Equinox one, and perhaps it only triggers in certain versions.
The problem has persisted for half a year now. I was hoping back then that this was just a temporary issue, and since the problem didn't resolve after I updated my JAX installations to CUDA 13, I decided to explore the source of it.
I tested several venvs with other JAX versions I still have installed on my machines (0.8.0, 0.4.38, 0.7.2). I also tested Python 3.10, 3.11 and 3.12. They all have the same issue.
How is the behaviour with Equinox different from the behaviour with plain JAX? In JAX you also see these spikes, but looking at the graphs it is not completely clear to me whether they are comparable, other than qualitatively.
Are you referring to the issue (https://github.com/jax-ml/jax/issues/30185)? In that case, the spikes occur during recompilation, not during calls to an already jitted function. Spikes during compilation are acceptable, as compilation doesn't happen very often. But spikes during jitted calls completely ruin performance, especially for large-scale training. During these spikes, the GPU is not used, meaning something is happening on the CPU side.
It might not be Equinox directly. Perhaps it triggers something in JAX that creates this behavior. The problem is that I do not see this behavior with jitted functions using only JAX, and I also do not see it with Flax NNX. I therefore cannot create an issue on JAX's repository, as they will most likely refer me to you. If this behavior is caused by JAX, I need to pinpoint exactly where it originates and create a pure JAX example before I can create an issue.
For reference, here is a ViT trained with NNX. There are also small spikes, but each spike up is followed by a corresponding spike down, and they compensate for each other.
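For completeness, one way to check that the extra time is spent on the host rather than on the device is a measurement of this shape. This is only a sketch; train_step, model, opt_state and batch stand in for the code in the repo:

import time
import jax

t0 = time.perf_counter()
out = train_step(model, opt_state, batch)   # returns once the work is dispatched
t1 = time.perf_counter()
jax.block_until_ready(out)                  # waits for the GPU to finish
t2 = time.perf_counter()

print(f"host/dispatch: {t1 - t0:.4f}s, device wait: {t2 - t1:.4f}s")
# If the spikes are purely host-side, the first number should blow up during
# a spike while the second stays small, i.e. the GPU sits idle while the CPU
# is busy.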
Got it, sorry! Have you confirmed that no recompilation is taking place on the Equinox side of things, e.g. by checking this with eqx.debug.assert_max_traces? (Just to make sure.)
The problem has persisted for half a year now.
Can you trace it back to an Equinox release? If so this would be extremely helpful in narrowing things down.
I have a hunch and am following it now with your example - maybe I'll get something :)
Got it, sorry! Have you confirmed that no recompilation is taking place on the Equinox side of things, e.g. by checking this with eqx.debug.assert_max_traces? (Just to make sure.)
Just checked it. The code compiles once.
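For reference, the check is of this shape (a sketch; train_step stands in for the actual step function in the repo):

import equinox as eqx

# Assert that the function is traced (and hence compiled) at most once;
# a second trace raises an error instead of silently recompiling.
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def train_step(model, opt_state, batch):
    ...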
Can you trace it back to an Equinox release? If so this would be extremely helpful in narrowing things down.
I would say the beginning of 2025, so 0.11.12. Before that, I did not notice it, but I also did not do performance-critical things, so I might have missed it :).
Followed up on my hunch - it seems that our filtered transformations have something to do with it.
When using jax.jit and closing over the static elements of the model after an eqx.partition(model, eqx.is_array), I still get spikes, but much fewer and farther between.
In contrast, on the same machine, using eqx.filter_{jit, value_and_grad} I get what you describe.
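To make the comparison concrete, the jax.jit variant is roughly of the following shape. This is a sketch under my own naming (compute_loss is a placeholder, the optimizer update is elided); the point is only that the static part of the model is closed over rather than passed through a filtered transformation:

import equinox as eqx
import jax

# `model` is the Equinox model from the repo.
params, static = eqx.partition(model, eqx.is_array)

@jax.jit
def train_step(params, batch):
    def loss_fn(p):
        m = eqx.combine(p, static)       # `static` is closed over, not traced
        return compute_loss(m, batch)    # placeholder loss function
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # ...optimizer update on `params` elided...
    return params, loss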
There were some changes to the filtered-transformations / filter_jit machinery earlier this year, particularly in 0.12.0 and 0.12.2, to reduce the small (additive) overhead these can produce in comparison to their JAX equivalents from "very small" to "tiny". (0.11.12 was a compatibility release.) Could you try to narrow down when this started in terms of Equinox versions?
FWIW I ran this with Python 3.10 and JAX 0.6.2, downgrading dependencies accordingly.
The issue remains in 0.11.12, 0.11.11, and 0.10.11 (found this in an old venv). I have trouble installing older versions because I cannot resolve the dependencies between JAX and jaxlib. Unfortunately, I won't be able to work on this today :(.
0.11.11/0.11.12:
0.10.11 (the spikes are farther apart, but the time per step is smaller, so that is to be expected when the gc is called periodically):
Still, it seems like 0.10.11 is more performant, judging from the shorter durations of the training step? Even if the training step takes half the time, that does not fully account for the increased spacing between spikes.
So it looks like this problem got worse in recent versions, either due to a change in JAX/XLA that only has this effect in combination with Equinox, or due to a change on our end.
Yes, the batch size and the ViT are small and the GPU (an H100) is very strong, so any increase in overhead has an outsized influence.
To add a guess: as everything is immutable, whenever we change a PyTree we create a new object. Since we copy references to the leaves during this, references are shared between objects. Maybe something inside PyTrees/eqx.Module creates a lot of cyclic references during this process, and the garbage collector has problems resolving them.
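As a small illustration of the leaf sharing I mean (a self-contained sketch; the MLP is just a stand-in for the actual model):

import equinox as eqx
import jax

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=4, out_size=4, width_size=8, depth=2, key=key)

# "Changing" one leaf produces a brand-new module object...
new_mlp = eqx.tree_at(lambda m: m.layers[0].weight, mlp,
                      mlp.layers[0].weight + 1.0)
print(new_mlp is mlp)  # False: a new object was created

# ...while every untouched leaf is shared by reference with the old module.
print(new_mlp.layers[-1].weight is mlp.layers[-1].weight)  # True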
Some kind of cyclic garbage during our jit-dispatch seems like a plausible reason, I agree. It should probably be possible to identify this with enough gc magic.
I ran gc.get_count() in the loop. These are my results:
Iteration duration: 0.007998943328857422 gc.get_count(): (11, 0, 17)
Iteration duration: 0.049581050872802734 gc.get_count(): (11, 7, 0) <------------------------------
Iteration duration: 0.008635759353637695 gc.get_count(): (11, 1, 1)
Iteration duration: 0.00812220573425293 gc.get_count(): (11, 7, 1)
Iteration duration: 0.00818634033203125 gc.get_count(): (11, 1, 2)
Iteration duration: 0.007991790771484375 gc.get_count(): (11, 7, 2)
Iteration duration: 0.008106708526611328 gc.get_count(): (11, 1, 3)
Iteration duration: 0.007803201675415039 gc.get_count(): (11, 7, 3)
Iteration duration: 0.010233879089355469 gc.get_count(): (11, 1, 4)
Iteration duration: 0.00926065444946289 gc.get_count(): (11, 7, 4)
Iteration duration: 0.009359598159790039 gc.get_count(): (11, 1, 5)
Iteration duration: 0.009006261825561523 gc.get_count(): (11, 7, 5)
Iteration duration: 0.009138822555541992 gc.get_count(): (11, 1, 6)
Iteration duration: 0.008296012878417969 gc.get_count(): (11, 7, 6)
Iteration duration: 0.008239984512329102 gc.get_count(): (11, 1, 7)
Iteration duration: 0.007908344268798828 gc.get_count(): (11, 7, 7)
Iteration duration: 0.008116960525512695 gc.get_count(): (11, 1, 8)
Iteration duration: 0.007721900939941406 gc.get_count(): (11, 7, 8)
Iteration duration: 0.00793147087097168 gc.get_count(): (11, 1, 9)
Iteration duration: 0.007572650909423828 gc.get_count(): (11, 7, 9)
Iteration duration: 0.007769584655761719 gc.get_count(): (11, 1, 10)
Iteration duration: 0.007535219192504883 gc.get_count(): (11, 7, 10)
Iteration duration: 0.007733583450317383 gc.get_count(): (11, 1, 11)
Iteration duration: 0.008194446563720703 gc.get_count(): (11, 7, 11)
Iteration duration: 0.01064920425415039 gc.get_count(): (11, 1, 12)
Iteration duration: 0.008659601211547852 gc.get_count(): (11, 7, 12)
Iteration duration: 0.00844717025756836 gc.get_count(): (11, 1, 13)
Iteration duration: 0.008003473281860352 gc.get_count(): (11, 7, 13)
Iteration duration: 0.008090496063232422 gc.get_count(): (11, 1, 14)
Iteration duration: 0.007787942886352539 gc.get_count(): (11, 7, 14)
Iteration duration: 0.008152008056640625 gc.get_count(): (11, 1, 15)
Iteration duration: 0.008344650268554688 gc.get_count(): (11, 7, 15)
Iteration duration: 0.008121967315673828 gc.get_count(): (11, 1, 16)
Iteration duration: 0.00774693489074707 gc.get_count(): (11, 7, 16)
Iteration duration: 0.00788426399230957 gc.get_count(): (11, 1, 17)
Iteration duration: 0.007604122161865234 gc.get_count(): (11, 7, 17)
Iteration duration: 0.007926464080810547 gc.get_count(): (11, 1, 18)
Iteration duration: 0.007570743560791016 gc.get_count(): (11, 7, 18)
Iteration duration: 0.0077741146087646484 gc.get_count(): (11, 1, 19)
Iteration duration: 0.007539987564086914 gc.get_count(): (11, 7, 19)
Iteration duration: 0.04712271690368652 gc.get_count(): (15, 0, 0) <------------------------------
You can see that whenever the last number of get_count is reset to 0, i.e., the objects are collected, the iteration duration spikes. If I am not mistaken, this is the count for the old generation (i.e., objects that survived the first gc scan) that has already been scanned. I guess the point where this happens is called a "full scavenge": https://github.com/python/cpython/blob/3.14/InternalDocs/garbage_collector.md#Optimization-incremental-collection, where the gc scans the entire heap for garbage objects.
However, I do not understand why an immutable object that is no longer used should survive the first scan (while it is still in the young generation). Maybe something is creating a "linked list"-type reference scheme: if one end is in the old, already-visited generation, the linked list will not be collected until the next full scavenge.
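One way to check this further, beyond gc.get_count(), would be to register a gc callback and log which generation each collection runs on and how long it takes (a sketch using the standard gc.callbacks hook; not something that is in the repo yet):

import gc
import time

# Log every collection: which generation ran, how many objects were collected,
# and how long the collection itself took.
_start = {}

def gc_log(phase, info):
    if phase == "start":
        _start["t"] = time.perf_counter()
    else:  # phase == "stop"
        dt = time.perf_counter() - _start.get("t", time.perf_counter())
        print(f"gen={info['generation']} collected={info['collected']} "
              f"uncollectable={info['uncollectable']} took {dt * 1e3:.1f} ms")

gc.callbacks.append(gc_log)

If the long iterations coincide with the old-generation (full) collections, that would confirm the picture above.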
Hm, so I tried this on different Python versions (3.12, 3.13 and 3.14) using JAX 0.8.0 on a 5090, and across those versions it looks a bit different as well; not sure how helpful this is, though:
3.12:
3.13:
3.14:
(Don't mind the commit hash, it's the current main one. I wanted to git bisect this but dependency hell got hands and defeated me)
Hi,
sorry for the late reply. @Artur-Galstyan: thank you for reproducing this across different Python versions. So this does not seem to be on the Python side.
I dug a bit deeper. Apparently, just flattening an equinox.Module twice in the same context causes the behavior (this happens implicitly in eqx.partition and thus causes the issue when using a function with filter_jit). Here is an example (pushed it into the example repo under example_flatten.py):
import jax

def flatten_twice(model):  # spikes
    # Flattening the same module twice within one call triggers the gc spikes.
    unflattened = jax.tree.flatten(model)
    unflattened = jax.tree.flatten(model)

def flatten_once(model):  # no spikes
    unflattened = jax.tree.flatten(model)
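A minimal timing loop to compare the two, using the functions above, could look like this (a sketch; the eqx.nn.MLP is just a stand-in for the ViT used in the repo):

import gc
import time

import equinox as eqx
import jax

model = eqx.nn.MLP(in_size=64, out_size=64, width_size=256, depth=4,
                   key=jax.random.PRNGKey(0))  # stand-in for the actual model

for step in range(200):
    t0 = time.perf_counter()
    flatten_twice(model)   # swap for flatten_once(model) to compare
    dt = time.perf_counter() - t0
    print(f"step {step}: {dt * 1e3:.3f} ms, gc counts: {gc.get_count()}")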
Can you reproduce this behavior?