optax.MultiSteps out of memory
I always get an out of memory error using optax.MultiSteps, even when every_k_schedule=1. Using optax.apply_every(k=1) in a chain works fine.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    # optax.apply_every(k=1)
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
Later I'm using
opt_state = optimizer.init(params)
and
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
I have no idea what I could be doing wrong. I'm not changing anything else, like batch size.
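For completeness, here is a minimal self-contained sketch of the setup described above (the toy parameters, loss function, and shapes are made up purely for illustration; the optimizer wiring follows the snippets in this post):

import jax
import jax.numpy as jnp
import optax

# Toy parameters and loss, only to exercise the init/update pattern.
params = {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}

def loss_fn(p, x):
    return jnp.mean((x @ p["w"] + p["b"]) ** 2)

lr = 1e-3
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)

opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x):
    grads = jax.grad(loss_fn)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

params, opt_state = train_step(params, opt_state, jnp.ones((8, 1024)))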
Hi! Interesting - thanks for reporting this!
Are you also at more than ~2/3 memory usage when you use apply_every? From a first look, I could see that the implementation of apply_every returns 0*updates for skipped steps, while MultiSteps constructs a new array of zeros (even if every_k_schedule=1), so the former has a better memory footprint. This would explain higher memory usage of up to 50% - but not more.
I'm not sure why the two functions use completely different code paths - we should be able to merge them (and deprecate one of them).
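To illustrate the difference (a rough sketch of the two strategies as described above, not the actual optax source):

import jax
import jax.numpy as jnp

updates = {"w": jnp.ones((1024, 1024)), "b": jnp.ones((1024,))}

# apply_every-style: return 0*updates on skipped steps, i.e. scale the
# existing update pytree by a 0/1 flag.
emit = 0.0  # 0.0 on a skipped step, 1.0 on the k-th step
skipped_updates = jax.tree_util.tree_map(lambda g: emit * g, updates)

# MultiSteps-style (as described above): construct a fresh pytree of zeros,
# i.e. an additional buffer per parameter on top of the accumulated grads.
zero_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)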
I have most of my available memory preallocated by JAX. I tried reducing the batch size from 120 (which works with apply_every) to 30, but it still crashed with MultiSteps.
I am training Llama 2 7B on TPU. Without optax.MultiSteps my batch_size can be 4. However, after applying optax.MultiSteps, I get OOM even when batch_size is 1.
I can confirm that the MultiSteps implementation has a much larger memory overhead than just one extra buffer for the gradients (something like 4x extra buffers). This is very problematic when using this class with large models.
I also noticed this issue
I am having this issue as well for use in diffusion models
Facing the same issue.
Hi everyone, thanks for flagging this. I just merged a new version of optax.MultiSteps which should be more memory friendly - could you check it, please?
you're a king
Hi @hbq1! Thank you for the fix!
One question: I am still seeing larger memory consumption with MultiSteps compared with apply_every. Is this expected?
As a follow-up, I did some debugging myself, and it seems the problem is in this part of the code (line 414):
new_updates, new_state = jax.lax.cond(
    state.mini_step < k_steps - 1,
    _mid_step, _final_step, *(state, params, acc_grads))
If I understand it correctly, JAX allocates memory for both branch outputs (_mid_step and _final_step), so this basically doubles the space needed to store the optimizer state and gradients.
Still trying to figure out a way to solve it, though.
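For what it's worth, one way to sidestep the cond-over-pytrees pattern is the masking trick that apply_every already uses: run the same arithmetic on every step and multiply the emitted updates by a 0/1 flag, so only one set of output buffers exists. A rough sketch only - the names (mini_step, acc_grads, k_steps) mirror the MultiSteps state, but the function is illustrative, not the actual fix:

import jax
import jax.numpy as jnp

def accumulate_and_maybe_emit(updates, acc_grads, mini_step, k_steps):
    # emit is 1.0 on the final mini-step of the cycle and 0.0 otherwise,
    # so the same arithmetic runs on every step and no second branch
    # output needs to be allocated.
    emit = jnp.asarray(mini_step == k_steps - 1, dtype=jnp.float32)
    new_acc = jax.tree_util.tree_map(lambda a, g: a + g, acc_grads, updates)
    # Average over the accumulation window, zeroed out on non-emitting steps.
    emitted = jax.tree_util.tree_map(lambda a: emit * a / k_steps, new_acc)
    # Reset the accumulator after an emitting step, keep it otherwise.
    new_acc = jax.tree_util.tree_map(lambda a: (1.0 - emit) * a, new_acc)
    new_mini_step = (mini_step + 1) % k_steps
    return emitted, new_acc, new_mini_step

The inner optimizer step would then also have to run unconditionally or be masked in the same way, which is roughly what folding apply_every's logic into MultiSteps amounts to.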
Just added a PR merging the apply_every logic into MultiSteps. From my initial tests, it reduces the memory footprint (I'm able to train Llama 2 7B on a v3-8 now) without affecting convergence.
This is really great!
Awesome work @celiolarcher!
jax.lax.cond seems to be suboptimal in some use cases. E.g. here, in theory, it should understand that only one of _mid_step or _final_step will be executed, so it shouldn't allocate memory for both outputs. It might be something that the JAX/XLA devs would like to have a look at. Let me know if you'd like me to file a bug at https://github.com/google/jax/issues, or feel free to do it yourself ofc!
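If that JAX issue gets filed, a small standalone reproducer along these lines might be useful (purely illustrative - it just runs lax.cond over two branches that return large pytrees, so peak device memory can be compared against a masked, cond-free version):

import jax
import jax.numpy as jnp

def skip_branch(x):
    # Stand-in for _mid_step: return zeros with the same structure.
    return jax.tree_util.tree_map(jnp.zeros_like, x)

def emit_branch(x):
    # Stand-in for _final_step: return a transformed copy of the input.
    return jax.tree_util.tree_map(lambda a: a * 2.0, x)

@jax.jit
def step(flag, x):
    return jax.lax.cond(flag, skip_branch, emit_branch, x)

x = {"w": jnp.ones((4096, 4096)), "b": jnp.ones((4096,))}
out = step(jnp.asarray(True), x)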
I'm glad to be able to help! About the issue, @hbq1, I can open it there, no problem.