
Memory growing when integrating lots of (different) functions

Open · fhchl opened this issue 1 year ago · 3 comments

Hi again 😃

I noticed that memory seems to grow indefinitely when integrating lots of newly defined functions. In the following example, memory grows at each iteration if vector_field is defined inside solve_ode. I expect that this triggers a recompilation at each iteration, and that one could define vector_field globally to avoid that. But I also expected the compilation cache entry for the locally defined function to be cleared when solve_ode returns.
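To illustrate why a locally defined vector_field defeats caching, here is a small stdlib-only sketch. The hypothetical fake_compile function stands in for a compilation cache keyed on a static argument, with functools.lru_cache playing the role of that cache; it is an analogy, not how JAX or Diffrax actually implement caching.

```python
import functools

@functools.lru_cache(maxsize=None)
def fake_compile(fn):
    # Hypothetical stand-in for "compile fn": the cache key is the function
    # object itself (hashed by identity), like a static argument.
    return f"compiled {fn.__name__}"

def global_vector_field(t, x, args):
    return -0.1 * x

def solve_local():
    # A fresh function object is created on every call...
    def vector_field(t, x, args):
        return -0.1 * x
    return fake_compile(vector_field)  # ...so every call is a cache miss.

def solve_global():
    # The same module-level function object is reused on every call.
    return fake_compile(global_vector_field)

for _ in range(3):
    solve_local()
    solve_global()

info = fake_compile.cache_info()
print(info.misses, info.hits, info.currsize)  # 4 2 4: 3 local misses + 1 global miss
```

Note that the cache also keeps every one of the local functions alive, which is why the entry count keeps growing instead of old entries disappearing.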

My use case here is benchmarking tons of different ODE models on lots of different data in a single script. Although I take care to free the Equinox modules containing the vector_fields, deleting those objects does not seem to clear the compilation cache, so the system runs out of memory (at least, that's my guess at what is happening).

I noticed the clear_cache() snippet in the test suite, but applying it here doesn't seem to free the right resources.

```python
import gc, psutil, sys
import jax.numpy as jnp
import diffrax as dfx

def solve_ode():
  # A new vector_field function object is created on every call, so
  # diffeqsolve sees a new static argument (and recompiles) each time.
  def vector_field(t, x, _): return -0.1*x
  t = jnp.arange(4800)
  sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(vector_field),
    solver=dfx.Dopri5(),
    t0=t[0], t1=t[-1], dt0=t[1], y0=1., max_steps=len(t))
  return sol.ys

def clear_caches():
  process = psutil.Process()
  if process.memory_info().rss > 0.5 * 2**30: # >500MB memory usage
    # Iterate over a snapshot: getattr() below can trigger lazy imports,
    # which would otherwise mutate sys.modules during iteration.
    for module_name, module in list(sys.modules.items()):
      if module_name.startswith("jax"):
        for obj_name in dir(module):
          obj = getattr(module, obj_name, None)  # tolerate attributes that raise AttributeError
          if hasattr(obj, "cache_clear"):  # functools.lru_cache-style caches
            obj.cache_clear()
    gc.collect()
    print("Cache cleared")

# loop that grows memory
for i in range(100):
  res = solve_ode()
  clear_caches()
  print(f"Process uses {psutil.Process().memory_info().rss / (1024 * 1024)} MB memory")
```

fhchl · Aug 06 '22 08:08

Hey there! Hmm, this kind of scenario is rather annoying, I agree. I think this is probably a JAX issue, in that it is failing to properly clear its compilation cache.

Good find locating the cache-clearing hack in the test suite. This would have been my first suggestion to try, but it is definitely a hack.

I'd recommend opening an issue on the JAX GitHub issue tracker and seeing whether there is any further advice to be found there. (If nothing else, to let them know that this is a use case people have.)

A possible workaround may be to restart the Python process.
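One way to apply this workaround to a long benchmark is to run each model in its own short-lived interpreter, so that whatever a run caches is released when that process exits. A minimal stdlib-only sketch, assuming each job can be written as a self-contained Python snippet that prints its result; the run_isolated helper and the job strings are hypothetical.

```python
import subprocess
import sys

def run_isolated(job_source: str, timeout: float = 600.0) -> str:
    """Run job_source in a fresh Python interpreter and return its stdout.

    All memory held by the child interpreter (including any compilation
    caches) is returned to the OS when the child process exits.
    """
    completed = subprocess.run(
        [sys.executable, "-c", job_source],
        capture_output=True,
        text=True,
        timeout=timeout,
        check=True,  # raise CalledProcessError if the job fails
    )
    return completed.stdout.strip()

# Hypothetical jobs; in practice each would import diffrax, build one model,
# solve it and print a summary of the result.
jobs = [f"print({n} * {n})" for n in (1, 2, 3)]
results = [run_isolated(job) for job in jobs]
print(results)  # ['1', '4', '9']
```

The cost is paying interpreter start-up and compilation once per job, so it is best suited to jobs that each run for a while.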

In passing, and unrelated to this issue: judging from your example, you may like stepsize_controller=ConstantStepSize(compile_steps=True) (or StepTo(compile_steps=True)), which will specialise the compilation on the exact number of steps being taken, and in doing so may reduce compile/runtimes. (At the expense of necessarily recompiling if this changes.)

patrick-kidger · Aug 07 '22 13:08

Thanks for these comments (definitely will try compile_steps)!

I poked around a little more and found that the main culprit seems to be the cache at dfx.diffeqsolve._cached! Clearing that one as well as the JAX caches, et voilà, only a tiny memory leak is left! 🥳

```python
def clear_caches():
  ...
  gc.collect()
  dfx.diffeqsolve._cached.clear_cache()
  print("Cache cleared")
```
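Comparing the two clear_caches versions suggests why the generic hack alone did not help: it only visits modules whose names start with "jax" and only calls methods named cache_clear (the functools.lru_cache convention), whereas the cache found here is reached through a Diffrax object and is emptied via clear_cache. The stdlib-only sketch below reproduces that blind spot with hypothetical stand-in modules (jax_fake, diffrax_fake) and a hypothetical FakeJittedFunction class; it does not touch the real libraries.

```python
import functools
import types

def generic_clear(modules):
    """Mirror of the clear_caches hack: call cache_clear() on objects found
    in modules whose names start with "jax"."""
    cleared = []
    for module_name, module in list(modules.items()):
        if module_name.startswith("jax"):
            for obj_name in dir(module):
                obj = getattr(module, obj_name, None)
                if hasattr(obj, "cache_clear"):
                    obj.cache_clear()
                    cleared.append(f"{module_name}.{obj_name}")
    return cleared

class FakeJittedFunction:
    """Hypothetical stand-in for the object at dfx.diffeqsolve._cached:
    its cache is emptied by clear_cache(), not cache_clear()."""
    def __init__(self):
        self.entries = {}
    def clear_cache(self):
        self.entries.clear()

@functools.lru_cache(maxsize=None)
def lowered(n):
    return n * 2

jax_mod = types.ModuleType("jax_fake")          # name matches the "jax" filter
jax_mod.lowered = lowered
diffrax_mod = types.ModuleType("diffrax_fake")  # name does not match
diffrax_mod.solve_cached = FakeJittedFunction()

lowered(1)
diffrax_mod.solve_cached.entries["vector_field_1"] = "executable"

modules = {"jax_fake": jax_mod, "diffrax_fake": diffrax_mod}
cleared = generic_clear(modules)
remaining_after_scan = len(diffrax_mod.solve_cached.entries)
print(cleared)               # ['jax_fake.lowered']
print(remaining_after_scan)  # 1: the scan never reaches this cache

diffrax_mod.solve_cached.clear_cache()  # explicit call, as in the fix above
print(len(diffrax_mod.solve_cached.entries))  # 0
```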

Maybe equinox.compile_cache could be improved by caching weak references instead (e.g. like here: https://github.com/google/jax/pull/11461)?

fhchl · Aug 08 '22 09:08

Nicely found!

So unfortunately, I don't think this is something that can be fixed in Equinox/Diffrax. The memory usage here is due to jax.jit (not equinox.filter_jit or equinox.compile_utils.compile_cache) producing a new element in the dfx.diffeqsolve._cached cache each time it is called with a different static argument (namely each new vector_field that is defined).
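Why deleting the objects that own vector_field frees nothing can be reproduced in plain Python: a cache keyed on a static argument holds a strong reference to that argument, so the function (and everything it closes over) stays alive for as long as its cache entry does. A minimal sketch, with a plain dict as a hypothetical stand-in for the jit cache and a weakref.ref used only to observe whether the function has been collected:

```python
import gc
import weakref

jit_cache = {}  # hypothetical stand-in: static argument -> "compiled" result

def fake_jit_call(fn, x):
    # Look up (or create) the "compiled" version for this static argument.
    if fn not in jit_cache:
        jit_cache[fn] = f"executable for {fn.__name__}"
    return fn(x)

def make_model():
    big = bytearray(10**6)  # ~1 MB of data captured by the closure
    def vector_field(x):
        return -0.1 * x + 0 * len(big)
    return vector_field

model = make_model()
fake_jit_call(model, 1.0)
probe = weakref.ref(model)

del model  # drop our only direct reference...
gc.collect()
alive_after_del = probe() is not None
print(alive_after_del)   # True: the cache key keeps it (and `big`) alive

jit_cache.clear()  # analogous to clearing the compilation cache
gc.collect()
freed_after_clear = probe() is None
print(freed_after_clear)  # True: now it can be freed
```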

Ideally jax.jit would track weakrefs to its static arguments, and drop those compilation results which now correspond to dead weakrefs. I'd suggest opening an issue on the JAX GitHub issue tracker.
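The weakref idea can be sketched in plain Python, assuming the cache can be keyed directly on the function object: weakref.WeakKeyDictionary drops an entry automatically once its key is garbage collected. This only illustrates the technique suggested here (and in the linked JAX PR); WeakStaticCache is a hypothetical class, not JAX's implementation. One caveat the sketch makes visible: the cached value must not hold a strong reference back to the function, otherwise the key can never die.

```python
import gc
import weakref

class WeakStaticCache:
    """Compile cache whose entries die with their static (function) argument."""

    def __init__(self):
        self._entries = weakref.WeakKeyDictionary()

    def get_or_compile(self, fn):
        compiled = self._entries.get(fn)
        if compiled is None:
            # Hypothetical "compiled" value; it deliberately does not
            # reference fn, or the weak key would be kept alive forever.
            compiled = f"executable for {fn.__name__}"
            self._entries[fn] = compiled
        return compiled

    def __len__(self):
        return len(self._entries)

cache = WeakStaticCache()

def solve_once():
    def vector_field(t, x, args):
        return -0.1 * x
    return cache.get_or_compile(vector_field)

for _ in range(100):
    solve_once()

gc.collect()
print(len(cache))  # 0: every locally defined vector_field has been collected
```

Weak keys only bound memory, not compile time: a function object that is recreated on every call is still "compiled" on every call.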

patrick-kidger · Aug 08 '22 16:08