apex
Use master weights for bfloat16 FusedAdam when master_weights=True
As mentioned in #1728, the FusedAdam optimizer ignores master_weights=True for bfloat16 parameters. This PR fixes that oversight. I have confirmed that the new behavior matches a hand-rolled master-weights implementation (manually copying between the bf16 parameters and an fp32 copy) combined with vanilla torch.optim.AdamW applied to the fp32 copy.
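For reference, the "by hand" baseline looks roughly like this: keep an fp32 master copy of each bf16 parameter, step a plain torch.optim.AdamW on the master, and copy the result back into the bf16 weights. This is only an illustrative sketch of the comparison described above, not apex's actual FusedAdam code; the tensor shapes and loop are made up for the example.

```python
import torch

# bf16 "model" parameter and its fp32 master copy
bf16_param = torch.zeros(4, dtype=torch.bfloat16)
master = bf16_param.detach().clone().float().requires_grad_(True)

# vanilla AdamW runs entirely on the fp32 master weights
opt = torch.optim.AdamW([master], lr=0.1)

for _ in range(3):
    # gradients arrive in bf16 from the backward pass; upcast to fp32
    bf16_grad = torch.ones(4, dtype=torch.bfloat16)
    master.grad = bf16_grad.float()
    opt.step()
    opt.zero_grad()
    # write the updated fp32 master back into the bf16 model weights
    bf16_param.copy_(master.detach().to(torch.bfloat16))
```

With master_weights=True, FusedAdam should produce updates matching this pattern for bf16 parameters, instead of stepping directly on the low-precision weights.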
Ping @minitu, it looks like you added this support originally -- could you take a look? Thanks!
LGTM, we only looked at adding master weights for FP16 AMP at the time of the original PR. @crcrpar Could you review this as well?