devjalx
The problem has persisted for half a year now. I was hoping back then that this was just a temporary issue, and since the problem didn't resolve after I updated...
Are you referring to the issue (https://github.com/jax-ml/jax/issues/30185)? In that case, the spikes occur during recompilation, not when calling an already-jitted function. Having these spikes during compilation is...
> Got it, sorry! Have you confirmed that no recompilation is taking place on the Equinox side of things, e.g. by checking this with `eqx.debug.assert_max_traces`? (Just to make sure.) Just...
The issue persists in 0.11.12, 0.11.11, and 0.10.11 (found this in an old venv). I have trouble installing older versions because I cannot resolve the dependencies between JAX and jaxlib. Unfortunately, today,...
Yes, the batch size and the ViT are small and the GPU (H100) is very strong, so any increase in per-step overhead has a disproportionately large effect. To add a guess: As...
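One way to see how much of each step is fixed overhead is to time a deliberately tiny jitted function. This is a sketch under assumptions: `tiny_step` and `time_steps` are hypothetical names, and the workload is far smaller than a real ViT step. Because JAX dispatches asynchronously, each call is followed by `block_until_ready()` so the timer measures completed work, and the first call is kept outside the timed loop so compilation is not counted:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def tiny_step(x):
    # Hypothetical stand-in for a small-batch forward pass.
    return jnp.tanh(x @ x.T).sum()

def time_steps(fn, x, n_iters=50):
    fn(x).block_until_ready()  # warm-up: compile once, outside the timed region
    durations = []
    for _ in range(n_iters):
        t0 = time.perf_counter()
        fn(x).block_until_ready()  # wait for the async dispatch to finish
        durations.append(time.perf_counter() - t0)
    return durations

durations = time_steps(tiny_step, jnp.ones((8, 64)))
median = sorted(durations)[len(durations) // 2]
print(f"median: {median * 1e3:.3f} ms, max: {max(durations) * 1e3:.3f} ms")
```

With a workload this small, the median is dominated by dispatch overhead, so occasional outliers far above it would point at something outside the computation itself.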
I ran `gc.get_count()` in the loop. These are my results:

- Iteration duration: 0.007998943328857422, `gc.get_count()`: (11, 0, 17)
- Iteration duration: 0.049581050872802734, `gc.get_count()`: (11, 7, 0)
Hi, sorry for the late reply. @Artur-Galstyan: thank you for reproducing this across different Python versions. So this does not seem to be a Python-side issue. I dug a bit...