Use master weights for bfloat16 FusedAdam when master_weights=True

Open · cbcase opened this issue 2 years ago · 2 comments

As mentioned in #1728, the FusedAdam optimizer ignores master_weights=True for bfloat16 parameters. This PR fixes that oversight. I have confirmed that the behavior now matches a "by hand" master-weights implementation: keeping an fp32 copy of the parameters manually, running vanilla torch.optim.AdamW on that fp32 copy, and copying the results back to the bf16 parameters.
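The following pure-Python simulation shows why an fp32 master copy matters for bfloat16 weights. It is an illustration, not apex code and not the test used in this PR. `to_bf16`, `to_fp32` and `train` are hypothetical helpers. bf16 rounding is emulated by keeping the top 16 bits of the float32 bit pattern. A constant step of 1e-4 stands in for an optimizer update, since Adam's per-step update is roughly `lr` in magnitude.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even).

    Simulated by taking the float32 bit pattern and keeping its top 16 bits.
    NaN/Inf and overflow handling are omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def train(steps: int, step_size: float, use_master: bool) -> float:
    """Apply `steps` identical updates of -step_size to a weight starting at 1.0.

    use_master=False: the bf16 weight is updated and stored directly in bf16.
    use_master=True: an fp32 master copy is updated, then cast to bf16 for the model.
    """
    master = to_fp32(1.0)   # fp32 master weight
    model_w = to_bf16(1.0)  # bf16 weight seen by the model
    for _ in range(steps):
        if use_master:
            master = to_fp32(master - step_size)
            model_w = to_bf16(master)
        else:
            # The update is rounded straight back to bf16; near 1.0 the bf16
            # spacing is ~0.004, so a 1e-4 step rounds away to nothing.
            model_w = to_bf16(model_w - step_size)
    return model_w


no_master = train(1000, 1e-4, use_master=False)
with_master = train(1000, 1e-4, use_master=True)
print(f"bf16 only:           {no_master}")
print(f"fp32 master weights: {with_master}")
```

Without the master copy the weight never leaves 1.0. With it, the fp32 copy accumulates the small steps and the bf16 view tracks the expected value of about 0.9 (to within bf16 precision).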

cbcase · Sep 22 '23 17:09

Ping @minitu, looks like you added this support originally -- could you take a look? Thanks

cbcase · Oct 16 '23 23:10

LGTM, we only looked at adding master weights for FP16 AMP at the time of the original PR. @crcrpar Could you review this as well?

minitu · Oct 17 '23 17:10