
Question: Gradient Accumulation

Open thiagolaitz opened this issue 1 year ago • 4 comments

Hello, does it support gradient accumulation or microbatches like those in the T5X repository? I didn't find a parameter for this in base.yml, maybe I just didn't see it? Thank you!

thiagolaitz avatar Apr 19 '24 14:04 thiagolaitz

We don't support that out of the box. We've found that tuning LR to be smaller is a better approach.

What is your use case?

rwitten avatar Apr 22 '24 20:04 rwitten

I'm training bigger models than before, so I can't use the same batch size on the same TPU. Got any recommended ablation studies on using gradient accumulation versus lowering the LR? Also, if I skip gradient accumulation, should I just linearly reduce the LR based on the batch size? Thanks!

thiagolaitz avatar Apr 22 '24 21:04 thiagolaitz

+1 Adding another use case: since TPU availability varies, we encounter situations where we initially train a model on a v4-128 TPU but later need to replicate the experiment on a v4-64 TPU, which has less memory. We therefore must use gradient accumulation to keep the results consistent.

rodrigo-f-nogueira avatar Apr 25 '24 10:04 rodrigo-f-nogueira

Simply adding the following code after the optimizer is constructed in optimizers.py supports gradient accumulation:

# Accumulate gradients and only apply the update every
# `accumulate_gradient_steps` steps.
if config.accumulate_gradient_steps > 1:
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)

hxssgaa avatar May 10 '24 09:05 hxssgaa

That solution (`if config.accumulate_gradient_steps > 1: optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)`) should work fine. However, we have since added gradient accumulation as a config option, which performs slightly more accurate accumulation by weighting each microbatch by its number of unpadded tokens.
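To illustrate why weighting by unpadded tokens matters: microbatches can contain different amounts of padding, so averaging per-microbatch means equally biases the accumulated loss toward heavily padded microbatches. The following is a hypothetical numerical sketch, not MaxText's actual implementation:

```python
import jax.numpy as jnp

# Per-token losses for two microbatches, with padding masks
# (1 = real token, 0 = padding).
losses = [jnp.array([0.5, 0.7, 0.0, 0.0]), jnp.array([0.2, 0.4, 0.6, 0.8])]
masks = [jnp.array([1.0, 1.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0, 1.0])]

# Naive accumulation: average the per-microbatch means equally,
# regardless of how many real tokens each microbatch holds.
naive = jnp.mean(jnp.array(
    [jnp.sum(l * m) / jnp.sum(m) for l, m in zip(losses, masks)]))

# Token-weighted accumulation: one global mean over all unpadded
# tokens, equivalent to weighting each microbatch by its token count.
weighted = (sum(jnp.sum(l * m) for l, m in zip(losses, masks))
            / sum(jnp.sum(m) for l, m in zip(losses, masks)))

print(float(naive), float(weighted))  # 0.55 vs ~0.5333
```

The token-weighted value matches what a single large batch without accumulation would have produced, which is the point of the more accurate scheme.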

gobbleturk avatar Aug 28 '24 02:08 gobbleturk

Great, thank you so much, @gobbleturk !

rodrigo-f-nogueira avatar Aug 28 '24 09:08 rodrigo-f-nogueira