
Fix bug with grad clipping and distributed Adam

Open • timmoon10 opened this issue 2 years ago • 3 comments

I've gotten incorrect results when using distributed Adam to train GPT-3 in FP16, due to a bug in how gradient clipping interacts with gradient scaling. In particular, there's an incorrect assumption that gradient clipping is applied before the gradients are unscaled, which results in extremely small gradients. The fix simply requires gradient clipping to handle the case where DistributedFusedAdam._grad_scale != 1.0. Note that torch.cuda.amp.GradScaler.unscale_ does not natively support the distributed optimizer, so I have to unscale by directly manipulating DistributedFusedAdam._grad_scale.
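For illustration, a minimal sketch of scale-aware clipping on ordinary `p.grad` tensors (not the actual Apex fix; DistributedFusedAdam shards its gradients internally, and both `grad_scale` and the helper name here are stand-ins for `DistributedFusedAdam._grad_scale` and its internal clipping logic):

```python
import torch

def clip_grad_norm_scaled(parameters, max_norm, grad_scale=1.0, eps=1e-6):
    # The gradients are still multiplied by the loss scale, so compute the norm
    # of the scaled gradients and divide out the scale before comparing against
    # max_norm. Clipping against the scaled norm directly would shrink the
    # gradients by roughly a factor of grad_scale once they are unscaled.
    grads = [p.grad for p in parameters if p.grad is not None]
    scaled_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]))
    unscaled_norm = scaled_norm / grad_scale
    clip_coef = max_norm / (unscaled_norm + eps)
    if clip_coef < 1.0:
        # Gradients stay scaled here; after the optimizer divides out
        # grad_scale, their norm comes out to max_norm as intended.
        for g in grads:
            g.detach().mul_(clip_coef)
    return unscaled_norm
```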

While I was touching the code, I also adopted some recommendations from @erhoo82:

timmoon10 avatar Oct 14 '22 22:10 timmoon10

Note that torch.cuda.amp.GradScaler.unscale_ does not natively support the distributed optimizer, so I have to unscale by directly manipulating DistributedFusedAdam._grad_scale.

qq: would https://github.com/NVIDIA/apex/blob/master/apex/transformer/amp/grad_scaler.py be useful?

crcrpar avatar Oct 14 '22 23:10 crcrpar

Interesting, incorporating the distributed optimizer into that is a good idea. I'm using the NeMo grad scaler, which seems to be an extension of the Apex version.
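Roughly, the idea would be something like this (a hypothetical sketch, not the actual apex/transformer or NeMo scaler; it also skips GradScaler's usual inf/NaN bookkeeping):

```python
import torch

class DistOptimGradScaler(torch.cuda.amp.GradScaler):
    # Hypothetical: a GradScaler that handles an optimizer which keeps its own
    # gradient scale (e.g. DistributedFusedAdam._grad_scale) instead of exposing
    # per-parameter .grad tensors that can be unscaled in place.
    def unscale_(self, optimizer):
        if hasattr(optimizer, "_grad_scale"):
            # Hand the current loss scale to the optimizer so it can divide it
            # out internally (during clipping and the parameter update). Note
            # this skips the inf/NaN checks GradScaler.unscale_ normally does.
            optimizer._grad_scale = self.get_scale()
        else:
            # Ordinary optimizers fall back to standard per-tensor unscaling.
            super().unscale_(optimizer)
```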

timmoon10 avatar Oct 15 '22 00:10 timmoon10

Apex's grad scaler there was just ported (copied) from NeMo per Sangkug's suggestion. Also, NeMo's (https://github.com/NVIDIA/NeMo/blob/18940b3b32cff290cf70d4a251b0e2f7b08e1525/nemo/collections/nlp/parts/nlp_overrides.py#L395) looks to be based on PyTorch's. Optimistically speaking, the implementation wouldn't have diverged that much, so I'm uncertain, though it could help you separate concerns.

crcrpar avatar Oct 15 '22 01:10 crcrpar

I've updated the Apex grad scaler to match the changes I've made in the NeMo grad scaler.

timmoon10 avatar Oct 20 '22 05:10 timmoon10