[BUG] Concern around mixed precision training where weights are in low precision
I noticed that in DeepSpeed, when training with fp16 or bf16, the weights are set to the lower precision. I am wondering if there is any chance of making this optional. For both bf16 and fp16 there is the risk of having the weight change "disappear" due to the low precision.
This paper first brought the issue to my attention: https://arxiv.org/abs/2010.06192
Empirically, I have found a lot of diffusion model training to have small gradient norms, often around 0.02 or so. In bf16, and possibly even fp16, it appears that the optimization step may not even register.
In fp16, more bits are allocated to the mantissa, so it's less risky, but it still seems like a potential issue.
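To make the rounding concern concrete, here is a minimal PyTorch sketch (not DeepSpeed-specific) contrasting how the same small weight update behaves in bf16, fp16, and fp32; the update size of 1e-3 is just an illustrative assumption:

```python
import torch

# A weight of 1.0 and a small per-element update (e.g. lr * grad on the
# order of 1e-3), applied directly in each storage dtype.
update = 1e-3

for dtype in (torch.bfloat16, torch.float16, torch.float32):
    w = torch.tensor([1.0], dtype=dtype)
    new_w = w + torch.tensor([update], dtype=dtype)
    # bf16 has only 7 explicit mantissa bits, so 1.0 + 1e-3 rounds back
    # to 1.0; fp16 (10 mantissa bits) and fp32 register the change.
    print(dtype, (new_w - w).item())
```

With fp32 master weights, these tiny updates accumulate across steps even when the low-precision copy would round them away.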
Fetching the dtype of the optimizer states or model weights does show they are in the reduced precision, but to make sure I also checked the GPU memory usage. The below is ZeRO stage-1 training of SDXL.
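For reference, a rough sketch of how one might inspect these dtypes after `deepspeed.initialize()`; the attribute names (`model_engine.optimizer`, the inner `.optimizer`) and the optimizer-state layout vary across ZeRO stages and DeepSpeed versions, so treat this as an assumption-laden illustration rather than an exact recipe:

```python
import torch

def report_dtypes(model_engine):
    """Print dtypes of one parameter and its optimizer state (illustrative)."""
    # Compute (forward/backward) parameters live on the wrapped module.
    for name, p in model_engine.module.named_parameters():
        print(f"param {name}: {p.dtype}")
        break  # one parameter is enough to see the compute dtype

    # The engine's optimizer is usually a ZeRO wrapper; the wrapped torch
    # optimizer (where Adam moments etc. live) is typically on .optimizer.
    opt = getattr(model_engine, "optimizer", None)
    inner = getattr(opt, "optimizer", opt)
    if inner is not None and hasattr(inner, "state"):
        for p, state in inner.state.items():
            for key, val in state.items():
                if torch.is_tensor(val):
                    print(f"optimizer state '{key}': {val.dtype}")
            break  # a single parameter's state is enough
```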
Additionally, I had previously mentioned here that the DeepSpeed BERT training example suffers significant performance loss when running in bf16.
Wanted to link this one here too: https://github.com/Lightning-AI/pytorch-lightning/issues/18016
Found the solution, sir?
Any solution yet?
@ethansmith2000, apologies for the delayed response. Assuming this remains relevant, please see my comments below.
I noticed that in DeepSpeed, when training with fp16 or bf16, the weights are set to the lower precision. I am wondering if there is any chance of making this optional.
Can you clarify your concerns? Mixed precision with ZeRO involves forward/backward computation in the lower precision and optimizer computation in fp32. Is your expectation different?
fetching the dtype of the optimizer states or model weights does show they are in the reduced precision
Can you describe how you observed that optimizer state is in reduced precision? The ZeRO design is to keep the master weights and optimizer states in fp32 precision.
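For context, a minimal sketch of the setup under discussion (bf16 compute with ZeRO stage 1, where DeepSpeed maintains fp32 master weights and optimizer state); the config keys follow the documented DeepSpeed JSON config, but the batch size, optimizer settings, stand-in model, and exact `deepspeed.initialize` arguments here are illustrative assumptions:

```python
import deepspeed
import torch

ds_config = {
    "train_batch_size": 8,               # illustrative value
    "bf16": {"enabled": True},           # forward/backward in bf16
    "zero_optimization": {"stage": 1},   # partition optimizer states
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 1e-4},          # illustrative value
    },
}

model = torch.nn.Linear(16, 16)          # stand-in for a real model

# Compute runs in bf16; master weights and optimizer state are kept
# in fp32 by the ZeRO optimizer.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```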
@SonicCodes and @Monohydroxides, FYI
@tjruwase
Can you describe how you observed that optimizer state is in reduced precision? The ZeRO design is to keep the master weights and optimizer states in fp32 precision.
I observed that the optimizer states are in fp32, and the forward/backward passes are in fp16/bf16. At the time, I believed the forward and backward passes could be carried out in whatever precision we desired, for example keeping certain blocks' parameters in fp32 while others were in fp16/bf16. However, this is currently not feasible.
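As a side note, the kind of per-block precision control described above can be sketched outside of DeepSpeed in plain PyTorch by casting selected submodules back to fp32; the stand-in model and the selection criterion (restoring `LayerNorm` blocks) are purely hypothetical examples, and keeping the forward pass consistent afterwards still requires autocast or manual input casts:

```python
import torch
import torch.nn as nn

# Stand-in model; a real model would be the diffusion/transformer network.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.LayerNorm(16),
    nn.Linear(16, 16),
)

# Cast everything to bf16, then restore selected blocks to fp32.
model = model.to(torch.bfloat16)
for module in model.modules():
    if isinstance(module, nn.LayerNorm):  # hypothetical selection criterion
        module.to(torch.float32)

for name, p in model.named_parameters():
    print(name, p.dtype)  # Linear weights in bf16, LayerNorm weights in fp32
```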