diffrax
Memory growing when integrating lots of (different) functions
Hi again 😃
I noticed that memory seems to grow indefinitely when integrating lots of newly defined functions. In the following example, memory grows at each iteration if `vector_field` is defined inside `solve_ode`. I expect that compilation is triggered at each iteration in that case, and one could define `vector_field` globally to avoid that. But I also expected that the compilation cache for the locally defined function would be cleared at the end of `solve_ode`.
My use case here is benchmarking tons of different ODE models on lots of different data in a single script. While I am trying to take care to free the Equinox modules containing the vector fields, deleting those objects does not seem to clear that compilation cache, so the system eventually runs out of memory (that's at least my guess about what is happening).
I noticed the `clear_cache()` snippet in the test suite, but applying it here doesn't seem to free the right resources.
```python
import gc, psutil, sys

import jax.numpy as jnp
import diffrax as dfx


def solve_ode():
    def vector_field(t, x, _):
        return -0.1 * x

    t = jnp.arange(4800)
    sol = dfx.diffeqsolve(
        terms=dfx.ODETerm(vector_field),
        solver=dfx.Dopri5(),
        t0=t[0], t1=t[-1], dt0=t[1], y0=1., max_steps=len(t))
    return sol.ys


def clear_caches():
    process = psutil.Process()
    if process.memory_info().rss > 0.5 * 2**30:  # >500MB memory usage
        for module_name, module in sys.modules.items():
            if module_name.startswith("jax"):
                for obj_name in dir(module):
                    obj = getattr(module, obj_name)
                    if hasattr(obj, "cache_clear"):
                        obj.cache_clear()
        gc.collect()
        print("Cache cleared")


# loop that grows memory
for i in range(100):
    res = solve_ode()
    clear_caches()
    print(f"Process uses {psutil.Process().memory_info().rss / (1024 * 1024)} MB memory")
```
Hey there! Hmm, this kind of scenario is rather annoying, I agree. I think this is probably a JAX issue, in that it is failing to properly clear its compilation cache.
Good find locating the cache-clearing hack in the test suite. This would have been my first suggestion to try, but it is definitely a hack.
I'd recommend opening an issue on the JAX GitHub issue tracker and seeing if there is any further advice to be found there. (If nothing else, to let them know that this is a use case that people have.)
A possible workaround may be to restart the Python process.
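For instance, a minimal sketch of that workaround, driving each benchmark from its own throwaway process so the compilation caches die with it. (`benchmark_one.py` and its `--model-id` flag are made-up names here; substitute whatever actually runs one of your models.)

```python
# Hypothetical driver: launch each benchmark as a separate Python process, so
# all JIT/compilation caches are released when that process exits.
import subprocess
import sys

for model_id in range(100):
    subprocess.run(
        [sys.executable, "benchmark_one.py", "--model-id", str(model_id)],
        check=True,
    )
```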
In passing, and unrelated to this issue: judging from your example, you may like `stepsize_controller=ConstantStepSize(compile_steps=True)` (or `StepTo(compile_steps=True)`), which will specialise the compilation on the exact number of steps being taken, and in doing so may reduce compile/runtimes. (At the expense of necessarily recompiling if this changes.)
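E.g. something like this, as an untested sketch based on your example above (assuming a diffrax version in which `ConstantStepSize` accepts a `compile_steps` argument):

```python
def solve_ode_fixed_steps():
    def vector_field(t, x, _):
        return -0.1 * x

    t = jnp.arange(4800)
    sol = dfx.diffeqsolve(
        terms=dfx.ODETerm(vector_field),
        solver=dfx.Dopri5(),
        t0=t[0], t1=t[-1], dt0=t[1], y0=1., max_steps=len(t),
        # Specialise compilation on the exact number of steps being taken.
        stepsize_controller=dfx.ConstantStepSize(compile_steps=True),
    )
    return sol.ys
```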
Thanks for these comments (definitely will try `compile_steps`)!
I poked around a little bit more and found out that the main culprit seems to be the cache at `dfx.diffeqsolve._cached`! Clearing that one and the JAX caches, et voilà, only a tiny memory leak left! 🥳
```python
def clear_caches():
    ...
    gc.collect()
    dfx.diffeqsolve._cached.clear_cache()
    print("Cache cleared")
```
Maybe `equinox.compile_cache` could be improved by caching weak references instead (e.g. like here: https://github.com/google/jax/pull/11461)?
Nicely found!
So unfortunately, I don't think this is something that can be fixed in Equinox/Diffrax. The memory usage here is due to `jax.jit` (not `equinox.filter_jit` or `equinox.compile_utils.compile_cache`) producing a new element in the `dfx.diffeqsolve._cached` cache each time it is called with a different static argument (namely, each new `vector_field` that is defined).
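You can see this effect in isolation with a toy example, independent of diffrax (behaviour here is my understanding of `jax.jit`'s static-argument caching and may differ slightly across JAX versions):

```python
import jax
import jax.numpy as jnp

def apply(fn, x):
    print("tracing with", fn)  # executes once per new cache entry
    return fn(x)

apply_jit = jax.jit(apply, static_argnums=0)

for _ in range(3):
    def vf(x):  # a brand-new function object on every iteration
        return -0.1 * x
    # Each distinct `vf` is a distinct static argument, so this traces,
    # compiles, and caches a new entry every time.
    apply_jit(vf, jnp.ones(()))
```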
Ideally `jax.jit` would track weakrefs to its static arguments, and drop those compilation results which now correspond to dead weakrefs. I'd suggest opening an issue on the JAX GitHub issue tracker.
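For reference, the kind of behaviour meant here, sketched with a plain `WeakKeyDictionary` (a toy illustration, not how JAX is actually implemented internally):

```python
import gc
import weakref

# Toy cache: the key (the static argument) is held only weakly, so the cached
# "compiled" artifact is dropped as soon as the last strong reference dies.
_compiled = weakref.WeakKeyDictionary()

def get_or_compile(static_fn, compile_fn):
    if static_fn not in _compiled:
        _compiled[static_fn] = compile_fn(static_fn)
    return _compiled[static_fn]

def fake_compile(fn):
    return f"compiled<{fn.__name__}>"

def make_vector_field():
    def vector_field(t, x, args):
        return -0.1 * x
    return vector_field

vf = make_vector_field()
get_or_compile(vf, fake_compile)
print(len(_compiled))  # 1
del vf
gc.collect()
print(len(_compiled))  # 0 -- the cache entry died with the function
```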