jaxopt

Possible memory leak when calling solver.run multiple times

Open · alucantonio opened this issue 2 years ago · 17 comments

I am trying to solve a problem where solver.run is called multiple times to minimize a series of functions while varying a parameter. Using memory_profiler, I can see that the allocated memory increases each time solver.run is called and never decreases.

Here is a minimal example to reproduce the issue:

import jax.numpy as jnp
import jaxopt
from memory_profiler import profile

@profile
def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)

And here is the corresponding plot of the allocated memory: [Figure_1]

Can you please confirm the issue or provide a solution for that? Thanks. Alessandro

alucantonio, Jan 18 '23

Is it specific to LBFGS or does it happen with any solver?

mblondel, Jan 23 '23

Hi, I have experienced the issue with LBFGS and GradientDescent. The increase in memory is less evident with GradientDescent, but it is still there. I believe the issue does not depend on the solver.

alucantonio, Jan 23 '23

Can you also check if the jit and unroll options to LBFGS have any impact on this? Depending on these options, a different loop implementation is used.

Normally, solver objects don't store anything, so I'm not sure where this could come from...

CC @fabianp, @froystig

mblondel, Jan 24 '23

With LBFGS, both jit=False, unroll=True and jit=True, unroll=False still produce an increase in memory after each call of solver.run.

alucantonio, Jan 25 '23

I don't have a solution for this, but I can confirm that it also happens with the update API, i.e., when updates are run inside a for loop:

import jax
import jax.numpy as jnp
import jaxopt

def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, implicit_diff=False, maxiter=100)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    # step the solver manually with the update API instead of solver.run
    for _ in range(solver.maxiter):
        params, state = jitted_update(params, state, min=mm)

for i in range(10):
    optimize(i)

fabianp, Feb 08 '23

Some updates on my investigations.

  1. Following @mblondel's suggestion, I set eq=True in the definition of LBFGS. It didn't help.
  2. I also modified the LBFGS class to remove the dataclass decorator. It didn't help.
  3. I'm inclined to think the issue is in the update method. The following code, which constructs the solver but doesn't perform the updates, doesn't have the memory leak:
import jax.numpy as jnp
import jaxopt
import jax
import gc
import time


def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        pass
    #     params, state = jitted_update(params, state, min=mm)
    time.sleep(1)

for i in range(10):
    optimize(i)
    gc.collect()

However, if I uncomment the update line inside the for loop (even for just 1 iteration), the leak comes back.
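When the update line is uncommented, each call to `optimize` builds a new `obj` closure and a new `jax.jit(solver.update)` wrapper. A minimal sketch in plain JAX (not jaxopt) of why that matters: `jax.jit` caches traced and compiled computations per function object, so a freshly created function is traced and compiled again even if its code and input shapes are identical. The trace counter is hypothetical instrumentation: a Python side effect inside a jitted function runs only while JAX traces it, so it counts cache misses. Whether this is the actual source of the growth in jaxopt is an assumption, not something established in this thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical instrumentation: Python side effects inside a jitted function
# run only while JAX traces it, so these counters count cache misses.
trace_count = {"fresh": 0, "shared": 0}


def make_fresh_objective():
    # Mimics `obj` being defined inside `optimize`: a new function object per call.
    def obj(x, m):
        trace_count["fresh"] += 1
        return jnp.square(x - m).sum()
    return obj


def shared_objective(x, m):
    # Defined once, so every call reuses the same function object.
    trace_count["shared"] += 1
    return jnp.square(x - m).sum()


shared_jitted = jax.jit(shared_objective)

x0 = jnp.zeros(1)
for i in range(10):
    m = jnp.array(float(i))
    jax.jit(make_fresh_objective())(x0, m)  # new function: traced and compiled again
    shared_jitted(x0, m)                   # same function, same shapes: cache hit

print(trace_count)
```

The fresh closures should be traced once per loop iteration, while the shared function is traced only on its first call.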

fabianp, Feb 14 '23

In your example, is there still a leak if the update is not jitted?

froystig, Feb 14 '23

Yeah, although there's a small decrease at the end, which could mean it's recovering some memory.

This is without jitting: [image]

and with jitting: [image]

As you can see, it also uses a lot more memory when it's not jitting. Not sure what to make of that.

fabianp, Feb 14 '23

Thanks for the investigations. I would like to know whether this behavior can be considered as a bug and whether there is any plan to fix it.

alucantonio, Feb 23 '23

Yes to both. It seems like a bug and should be fixed (although we're all spread too thin, so I wouldn't know how to set a timeline on it).

fabianp, Feb 23 '23

Hi, has there been any progress on this issue?

alucantonio, Apr 04 '23

This behavior can be avoided using the newly implemented jax.clear_caches() in jax (thanks @froystig!).

For example, the code below doesn't have the ever-increasing profile. Instead, it shows the more expected initial increase followed by a plateau:

[Figure_1]

import jax.numpy as jnp
import jaxopt
import jax
import gc
import time


def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
    jax.clear_caches()

fabianp, May 19 '23

I'm going to close the issue for now, but please reopen it if the problem persists (BTW, you might need the development version of jax for the clear_caches() function).
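Because jax.clear_caches() may be missing from older jax releases, a small guard lets the same loop run on either version and simply skips the cache clearing when the function is unavailable. `clear_jax_caches_if_available` is a hypothetical helper for illustration, not part of jax or jaxopt.

```python
import jax


def clear_jax_caches_if_available() -> bool:
    """Hypothetical helper: call jax.clear_caches() if this jax version has it.

    Returns True if the caches were cleared, False if the function is missing.
    """
    clear_caches = getattr(jax, "clear_caches", None)
    if clear_caches is None:
        return False  # older jax release without jax.clear_caches()
    clear_caches()
    return True
```

In the loop above, `jax.clear_caches()` would be replaced by `clear_jax_caches_if_available()`; on versions without the function, memory would presumably keep growing as before.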

fabianp, May 19 '23

It's nice to have a workaround but shouldn't garbage collection be able to do this automatically?

mblondel, May 19 '23

Maybe, but at this point it seems more of an issue concerning jax than jaxopt, wdyt?

fabianp, May 19 '23

Agreed!

mblondel, May 19 '23

@froystig made a good point in private conversation: this might be a symptom of jaxopt not using the cache properly and/or generating too many fresh functions instead of reusing cached ones.

I don't have the bandwidth to look into it right now, but I'm leaving the issue open in case someone can look into it more deeply.

fabianp, May 22 '23