
optax.MultiSteps out of memory

ein-ich opened this issue • 15 comments

I always get an out of memory error using optax.MultiSteps, even when every_k_schedule=1. Using optax.apply_every(k=1) in a chain works fine.

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    #optax.apply_every(k=1)
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)

Later I initialize with opt_state = optimizer.init(params) and then step with

updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

I have no idea what I could be doing wrong. I'm not changing anything else, like batch size.
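
For reference, here is a minimal, self-contained sketch of my setup (the shapes and the loss are placeholders for illustration, not my actual model):

import jax
import jax.numpy as jnp
import optax

lr = 1e-3
params = {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}

def loss_fn(p, x):
    return jnp.sum((x @ p["w"] + p["b"]) ** 2)

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    # optax.apply_every(k=1),  # this variant stays within memory
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)  # this variant OOMs

opt_state = optimizer.init(params)
x = jnp.ones((32, 1024))
grads = jax.grad(loss_fn)(params, x)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)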

ein-ich avatar Jan 07 '23 12:01 ein-ich

Hi! Interesting - thanks for reporting this!

Are you also at more than ~2/3 memory usage when you use apply_every? From a first look, the implementation of apply_every returns 0 * updates for skipped steps, while MultiSteps constructs a new array of zeros (even when every_k_schedule=1), so the former has a smaller memory footprint. That would explain up to ~50% higher memory usage - but not more.

I'm not sure why the two functions use completely different code paths - we should be able to merge them (and deprecate one of them).
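
Roughly, the two strategies look something like this (a simplified sketch, not the actual optax code):

import jax
import jax.numpy as jnp

def skipped_step_apply_every(updates, emit):
    # apply_every-style: scale the incoming updates by a 0/1 emit flag,
    # so no separate pytree of zeros needs to be materialised.
    return jax.tree_util.tree_map(lambda g: emit * g, updates)

def skipped_step_multi_steps(updates):
    # MultiSteps-style: build a brand-new pytree of zeros for the skipped
    # step, on top of the accumulated-gradient buffers it already keeps.
    return jax.tree_util.tree_map(jnp.zeros_like, updates)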

mkunesch avatar Jan 08 '23 18:01 mkunesch

I have most of my available memory preallocated by JAX. I tried reducing the batch size from 120 (which works with apply_every) to 30, but it still crashed with MultiSteps.

ein-ich avatar Jan 08 '23 19:01 ein-ich

I am training Llama 2 7B on TPU. Without optax.MultiSteps my batch_size can be 4. However, after applying optax.MultiSteps, I get OOM even with batch_size set to 1.

ayaka14732 avatar Jul 22 '23 11:07 ayaka14732

I can confirm that the MultiSteps implementation has a much larger memory overhead than just one extra buffer for the gradients (something like 4x extra buffers). This is very problematic when using this class with large models.

hr0nix avatar Aug 21 '23 20:08 hr0nix

I also noticed this issue

Sea-Snell avatar Aug 28 '23 23:08 Sea-Snell

I am having this issue as well when training diffusion models.

philippe-eecs avatar Aug 28 '23 23:08 philippe-eecs

Facing the same issue.

agrimgupta92 avatar Aug 29 '23 20:08 agrimgupta92

Hi everyone, thanks for flagging this. I just merged a new version of optax.MultiSteps which should be more memory-friendly. Could you check whether it helps, please?

hbq1 avatar Aug 30 '23 18:08 hbq1

you're a king

philippe-eecs avatar Aug 30 '23 19:08 philippe-eecs

Hi @hbq1! Thank you for the fix!

One question: I am still seeing higher memory consumption with MultiSteps compared with apply_every. Is this expected?

celiolarcher avatar Oct 06 '23 00:10 celiolarcher

As a follow-up, I did some debugging myself, and it seems the problem is in this part of the code (line 414):

new_updates, new_state = jax.lax.cond(
    state.mini_step < k_steps - 1,
    _mid_step, _final_step, *(state, params, acc_grads))

If I understand correctly, JAX allocates memory for the outputs of both branches (_mid_step and _final_step), so this basically doubles the space needed for the optimizer state and gradients.

Still trying to figure out a way to solve it, though.

celiolarcher avatar Oct 06 '23 15:10 celiolarcher

Just added a PR merging the apply_every logic into the MultiSteps wrapper. From my initial tests, it reduces the memory footprint (I am able to train Llama 2 7B on a v3-8 now) without affecting convergence.
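
The rough idea (a simplified sketch of the direction, not the actual PR code) is to emit masked updates on every step, as apply_every does, instead of branching with jax.lax.cond:

import jax
import jax.numpy as jnp

def emit_accumulated_updates(acc_grads, mini_step, k_steps):
    # 1.0 on the final micro-step of the accumulation window, 0.0 otherwise.
    emit = jnp.where(mini_step == k_steps - 1, 1.0, 0.0)
    # Masking instead of jax.lax.cond means XLA does not have to reserve
    # output buffers for two separate branch functions at the same time.
    return jax.tree_util.tree_map(lambda g: emit * g, acc_grads)

(emit_accumulated_updates and its arguments are just illustrative names here.)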

celiolarcher avatar Oct 10 '23 13:10 celiolarcher

This is really great!

mtthss avatar Oct 31 '23 11:10 mtthss

Awesome work @celiolarcher!

jax.lax.cond seems to be suboptimal in some use cases. Here, in theory, it should recognise that only one of _mid_step or _final_step will be executed, so it shouldn't allocate memory for both outputs. It might be something the JAX/XLA devs would like to look at. Let me know if you'd like me to file a bug at https://github.com/google/jax/issues, or feel free to do it yourself, of course!

hbq1 avatar Nov 23 '23 09:11 hbq1

I'm glad to be able to help! As for the JAX issue, @hbq1, I can open it there, no problem.

celiolarcher avatar Dec 01 '23 11:12 celiolarcher