
Gradient accumulation yields worse results than the equivalent batch size

Open benjamin-marie opened this issue 1 year ago • 20 comments

I expected a training configuration with per_device_train_batch_size=1 and gradient_accumulation_steps=32 to yield the same (or similar) result as per_device_train_batch_size=32 and gradient_accumulation_steps=1, but that's not the case: the former is much worse. I ran several experiments with SmolLM-135M and Llama 3.2 1B, always using the same seed, and the results are consistent with this observation.

[Learning curves for SmolLM-135M and Llama 3.2 1B: the gradient-accumulation runs train noticeably worse than the full-batch runs]

Maybe I misunderstand something here?

My training code is in this Colab notebook. I ran this notebook to draw the learning curves above, restarting the notebook between each training run to avoid OOM. Note that I observe the same behaviour with Qwen2.
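For reference, a minimal sketch of the two setups being compared (using trl's SFTConfig; every hyperparameter other than the batch settings is illustrative, not the exact notebook value):

```python
from trl import SFTConfig

# In theory these two configs give the same effective batch size of 32
# and should produce near-identical learning curves.
config_full_batch = SFTConfig(
    output_dir="out-bs32",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,   # illustrative value
    logging_steps=10,
    seed=42,              # same seed for both runs
)

config_grad_accum = SFTConfig(
    output_dir="out-bs1-ga32",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
    learning_rate=1e-5,
    logging_steps=10,
    seed=42,
)
```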

benjamin-marie avatar Oct 04 '24 15:10 benjamin-marie

Hey, this is expected behaviour. FSDP-1 only allows gradient accumulation in 16-bit precision. This is not the case for FSDP-2, which allows accumulation in both 16-bit and 32-bit.
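For readers who haven't touched this knob, a hedged sketch of the FSDP setting in question: the `MixedPrecision` policy's `reduce_dtype` controls the precision used for gradient reduction (values below are illustrative; whether between-step accumulation also honours it is exactly the FSDP-1 vs FSDP-2 difference described above):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# reduce_dtype sets the dtype used when gradients are reduce-scattered across ranks.
# Keeping it at float32 while params stay in bfloat16 trades communication
# bandwidth for more accurate gradient reduction.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# model = FSDP(model, mixed_precision=mp_policy)  # usual FSDP wrapping
```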

mayank31398 avatar Oct 06 '24 06:10 mayank31398

Documentation for FSDP-1: [screenshot]
Documentation for FSDP-2: [screenshot]

mayank31398 avatar Oct 06 '24 06:10 mayank31398

Interesting, I didn't know this. But I don't think it matters here: I would be surprised if TRL used FSDP's reduce-scatter for single-GPU training.

benjamin-marie avatar Oct 07 '24 05:10 benjamin-marie

Hi, thanks for reporting this. Can you share your system info and the code you use for training?

qgallouedec avatar Oct 07 '24 09:10 qgallouedec

Sure, it's all in the notebook I linked to in my first post. I ran this notebook on Colab with the A100.

benjamin-marie avatar Oct 07 '24 10:10 benjamin-marie

Someone tried it in fp32 and it didn't help, so that doesn't seem to be the reason:

https://x.com/bnjmn_marie/status/1842464802636980564

teknium1 avatar Oct 10 '24 15:10 teknium1

Have you tried a full/mixed-precision AdamW optimiser?

vigneshwaran avatar Oct 11 '24 04:10 vigneshwaran

Yes:

[Screenshot: learning curve for a run with fp32 and adamw_torch]

This configuration uses fp32 and adamw_torch.

benjamin-marie avatar Oct 11 '24 05:10 benjamin-marie

Hi, are there any updates? Thanks!

fzyzcjy avatar Oct 15 '24 06:10 fzyzcjy

I'm writing up a report about this - I think I managed to fix it :) (Yes it is in fact a subtle bug!) - will tweet and post about it in like 8 - 10 hours!

danielhanchen avatar Oct 15 '24 09:10 danielhanchen

We have fixed the issue guys!

Tweet: https://twitter.com/UnslothAI/status/1846231235749990699
Blogpost: https://unsloth.ai/blog/gradient

shimmyshimmer avatar Oct 15 '24 16:10 shimmyshimmer

> We have fixed the issue guys!

nice! feel like fixing it in TRL too?

geronimi73 avatar Oct 15 '24 17:10 geronimi73

> We have fixed the issue guys!
>
> nice! feel like fixing it in TRL too?

The Hugging Face team is already on it! :)

shimmyshimmer avatar Oct 15 '24 17:10 shimmyshimmer

(Somewhat - I'm currently trying to reverse-engineer a few of the ways you did it; you guys would be much faster at it, I imagine, if you want to beat us to it ;) This is more than TRL, though - I think it has to be fixed from the ground up in transformers/Trainer.)

muellerzr avatar Oct 15 '24 17:10 muellerzr

:) Wrote a detailed tweet about it: https://x.com/danielhanchen/status/1846235913443262891
Reddit post: https://www.reddit.com/r/LocalLLaMA/comments/1g4ego7/llm_training_bug_fixes_gradient_accumulation_was/
Blog post: https://unsloth.ai/blog/gradient
Also, @shimmyshimmer is my brother!! :)

danielhanchen avatar Oct 15 '24 18:10 danielhanchen

Just as a fair warning, this will not be an immediate or quick fix, since it essentially means every single model's loss calculation is off when using output.loss, and every single model will need a custom variation of CrossEntropy (and the other loss functions) if you do not calculate the loss by hand.

We are working on figuring out the best solution.
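To make the arithmetic concrete, here is a small hedged sketch (plain PyTorch, not the Trainer code) of why averaging per-micro-batch mean losses diverges from the full-batch loss once micro-batches contain different numbers of non-ignored tokens:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, ignore_index = 10, -100

# Two micro-batches with different numbers of valid (non-ignored) tokens.
logits_a, labels_a = torch.randn(6, vocab), torch.randint(0, vocab, (6,))
logits_b, labels_b = torch.randn(6, vocab), torch.randint(0, vocab, (6,))
labels_b[2:] = ignore_index  # only 2 valid tokens in the second micro-batch

# Full batch: a single mean over all 8 valid tokens.
full_loss = F.cross_entropy(
    torch.cat([logits_a, logits_b]), torch.cat([labels_a, labels_b]),
    ignore_index=ignore_index,
)

# Naive gradient accumulation: mean per micro-batch, then average the means.
# This weights the 2 valid tokens of micro-batch B as heavily as the 6 of A.
naive_loss = (
    F.cross_entropy(logits_a, labels_a, ignore_index=ignore_index)
    + F.cross_entropy(logits_b, labels_b, ignore_index=ignore_index)
) / 2

print(full_loss.item(), naive_loss.item())  # the two values generally differ
```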

muellerzr avatar Oct 15 '24 19:10 muellerzr

@danielhanchen From the blog: "The 2nd theory was there is in fact a bug in the loss calculation, which we find to be the case." Is this bug specific to the CrossEntropy loss calculation in HF trl? Would it not be an issue if someone is using, say, torch.nn.CrossEntropyLoss directly?

nahidalam avatar Oct 15 '24 23:10 nahidalam

@muellerzr, I believe this only applies to padding-based batches. With packing there are no pad tokens in the batch, so the average cross entropy is consistent.

huseinzol05 avatar Oct 16 '24 01:10 huseinzol05

@nahidalam Unfortunately this is not an HF-specific issue. The way gradient accumulation has traditionally been done in many packages, even those that use PyTorch directly, accidentally fails to account for ignored tokens. Using CE loss directly does not solve the issue: per-micro-batch mean reduction does not work, and sum reduction causes the loss to be scaled incorrectly.

@huseinzol05 Packing is also affected, albeit less so: some people also train on completions only, so even packed batches contain ignored tokens and the loss ends up incorrect there as well.
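Continuing the toy example above, a hedged sketch of the kind of correction being described, mirroring the idea rather than the actual transformers/Unsloth implementation: use sum reduction per micro-batch and normalize by the total number of non-ignored tokens across the whole accumulation window.

```python
import torch.nn.functional as F

def accumulated_losses(micro_batches, ignore_index=-100):
    """micro_batches: list of (logits, labels) pairs forming one effective batch.

    Returns per-micro-batch losses scaled so that summing their gradients
    matches a single backward pass over the full batch.
    """
    # Total number of non-ignored tokens across the whole accumulation window.
    total_valid = sum((labels != ignore_index).sum().item()
                      for _, labels in micro_batches)

    scaled = []
    for logits, labels in micro_batches:
        # Sum of per-token losses (no per-micro-batch mean)...
        loss_sum = F.cross_entropy(
            logits, labels, ignore_index=ignore_index, reduction="sum"
        )
        # ...normalized by the token count of the *whole* effective batch,
        # so every valid token carries the same weight as in one big batch.
        scaled.append(loss_sum / total_valid)
    return scaled  # call .backward() on each during accumulation
```

The catch is that the token count for the whole window has to be known (or gathered) before the per-micro-batch backward passes, which is part of what makes the fix invasive.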

@muellerzr If you guys need any help on anything, ping me!

danielhanchen avatar Oct 16 '24 01:10 danielhanchen

Kudos @danielhanchen on the fix! Neat write-up as well! Back to the OP, I think the issue isn't with the trl library, but with the transformers library instead, because of how SFTTrainer extends Trainer, how the loss is calculated in Trainer's compute_loss, and how it is naively scaled by the number of steps here. I don't have a ton of context, but I imagine the more principled solution would be to fix it within the Trainer.compute_loss function, vs say having SFTTrainer override the compute_loss method. Happy to assist with the transformers fix if anyone from HF would like to take me up on it 😄

wongjingping avatar Oct 16 '24 04:10 wongjingping

Does DDP have the same issue? @danielhanchen

qingjianbuyi avatar Oct 28 '24 13:10 qingjianbuyi

Yes, DDP does. We have already documented this, and a fix is being put in. (I also have an article talking about this in more detail; tl;dr: you can opt for a slower path that gathers all of the inputs/token counts, which adds a communication step that generally isn't recommended, so it's False by default.)
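A hedged sketch of the "gather the counts" option described here (names are illustrative, not the actual accelerate/transformers API): each rank all-reduces its count of non-ignored tokens so that every rank normalizes by the same global total.

```python
import torch.distributed as dist
import torch.nn.functional as F

def ddp_scaled_loss(logits, labels, ignore_index=-100):
    # Per-rank count of tokens that actually contribute to the loss.
    count = (labels != ignore_index).sum()

    # The extra communication step: sum the counts across all DDP ranks so
    # every rank divides by the same global token count.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(count, op=dist.ReduceOp.SUM)

    loss_sum = F.cross_entropy(
        logits, labels, ignore_index=ignore_index, reduction="sum"
    )

    # DDP averages gradients over ranks during its all-reduce, so scale back up
    # by world_size to preserve sum-over-tokens semantics (assumption about the
    # surrounding setup; drop this if gradients are summed rather than averaged).
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return loss_sum * world_size / count
```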

muellerzr avatar Oct 28 '24 14:10 muellerzr

Should this be closed since it's fixed in transformers?

cc @qgallouedec @lewtun

burtenshaw avatar Nov 25 '24 12:11 burtenshaw

Right @burtenshaw. Closed by https://github.com/huggingface/transformers/pull/34198

qgallouedec avatar Nov 25 '24 20:11 qgallouedec

[Screenshot: the issue still showing as open]

Time to hit that "Close Issue" button @qgallouedec @burtenshaw! :) I thought the issue was open because of that!

pminervini avatar Nov 27 '24 06:11 pminervini

Oops

qgallouedec avatar Nov 27 '24 07:11 qgallouedec

> @huseinzol05 Packing is also affected, albeit less so: some people also train on completions only, so even packed batches contain ignored tokens and the loss ends up incorrect there as well.

For a language modeling task, will this be a problem even if all samples in a batch have exactly the same sequence length?

surprisedPikachu007 avatar Dec 06 '24 05:12 surprisedPikachu007