
Why always use float32 precision in training?

Open · cizhenshi opened this issue 1 year ago

Is it necessary to use float32 in training? Why not use the widely used bf16 type, which needs less GPU memory? Looking forward to your reply, thanks!
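To put numbers on the memory question: bf16 stores each value in 2 bytes versus 4 for fp32, so one copy of the weights is half the size. The sketch below uses a hypothetical 7B-parameter model and one common accounting for Adam training state. That accounting is an assumption, not necessarily how this repository trains. It suggests that when mixed precision keeps an fp32 master copy, the weights-plus-optimizer state does not shrink.

```python
# Bytes per stored value for the two formats discussed in this thread.
BYTES = {"fp32": 4, "bf16": 2}


def weights_gb(num_params: int, dtype: str) -> float:
    """Memory for a single copy of the weights, in GB (10**9 bytes)."""
    return num_params * BYTES[dtype] / 1e9


def adam_state_gb(num_params: int, mixed_precision: bool) -> float:
    """Weights + gradients + Adam moments, under one common accounting.

    Assumption: mixed precision keeps bf16 weights and gradients for compute,
    plus an fp32 master copy of the weights and fp32 Adam moments (m, v);
    pure fp32 keeps all four tensors in fp32. Activations are not included.
    """
    if mixed_precision:
        per_param = 2 * BYTES["bf16"] + 3 * BYTES["fp32"]  # 2 + 2 + 4 + 4 + 4
    else:
        per_param = 4 * BYTES["fp32"]                      # 4 + 4 + 4 + 4
    return num_params * per_param / 1e9


n = 7_000_000_000  # hypothetical 7B-parameter model
print(weights_gb(n, "fp32"), weights_gb(n, "bf16"))     # 28.0 14.0
print(adam_state_gb(n, False), adam_state_gb(n, True))  # 112.0 112.0
```

Under this accounting both variants need 16 bytes per parameter for the training state. The savings from bf16 would then come mostly from activations stored in bf16 instead of fp32. Real numbers depend on the framework, optimizer, and sharding setup.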

cizhenshi · Feb 20 '24 07:02

Hi there! From what I know, bf16 can be beneficial when training large models or when GPU memory is limited. However, it trades away precision: bf16 keeps float32's 8-bit exponent (so roughly the same dynamic range) but has only 7 explicit mantissa bits instead of 23. That reduced precision may hurt convergence and final model quality, particularly in computations that need fine-grained numerical accuracy.
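To make the precision trade-off concrete: with 7 explicit mantissa bits, the gap between 1.0 and the next bf16 value is 2**-7. The hypothetical helper `to_bf16` below is a stdlib-only sketch that emulates bf16 rounding. It keeps the upper 16 bits of the float32 encoding, rounds to nearest with ties to even, and ignores NaN handling.

```python
import struct


def to_bf16(x: float) -> float:
    """Emulate rounding x to bfloat16 (round-to-nearest-even); NaN not handled."""
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round on the 16 low bits that bf16 drops, breaking ties toward an even result.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Keep the sign bit, 8 exponent bits and 7 mantissa bits.
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bf16(1 + 2**-7))    # 1.0078125: the smallest bf16 step above 1.0
print(to_bf16(1 + 2**-8))    # 1.0: a halfway case, rounded to even
print(to_bf16(1e30) / 1e30)  # close to 1: the fp32-like range is preserved
```

The last line shows why bf16 rarely overflows where fp16 would. Range is kept, and only resolution is lost.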

We could suggest that the authors of this model use mixed precision training. It combines the numerical precision of float32 (for example, for the master copy of the weights) with the memory savings of bf16, offering a balance between efficiency and accuracy.
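A minimal sketch of why mixed precision usually keeps an fp32 master copy of the weights. If the weight itself is stored in bf16, any update smaller than half the bf16 spacing near its value is rounded away. The toy SGD loop below uses hypothetical names and emulates the formats with the standard library. It illustrates the numerics only and is not how this repository or any framework implements mixed precision.

```python
import struct


def to_bf16(x: float) -> float:
    """Emulate rounding x to bfloat16 (round-to-nearest-even); NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sgd(steps: int, lr: float, grad: float, fp32_master: bool) -> float:
    """Run `steps` updates w <- w - lr * grad on a scalar weight starting at 1.0.

    The update is computed in bf16 in both variants; only the storage of the
    weight differs (fp32 master copy vs. bf16).
    """
    w = 1.0
    for _ in range(steps):
        update = to_bf16(lr * grad)
        if fp32_master:
            w = to_fp32(w - update)  # small updates still accumulate
        else:
            w = to_bf16(w - update)  # updates below half a bf16 step are rounded away
    return w


print(sgd(100, 1e-3, 1.0, fp32_master=False))  # 1.0: every update is lost
print(sgd(100, 1e-3, 1.0, fp32_master=True))   # roughly 0.9
```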

Another reason could be that the authors' hardware or software stack does not support bf16.

Hope this answers your query!

Mrinal96 · Feb 20 '24 08:02

In general, we didn't run into many memory bottlenecks for our needs, so we mostly stuck with fp32 to be safe (proper mixed precision training with bf16 requires more careful decisions about which parts of the network need to be computed in single vs. half precision).
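Reductions illustrate why choosing which parts run in single vs. half precision matters. A long sum accumulated in bf16 stops growing once each addend is smaller than half the bf16 spacing at the running total. This is one commonly cited reason to keep sums, such as softmax denominators or normalization statistics, in fp32. The sketch below emulates both accumulators with the standard library, and the helper names are hypothetical.

```python
import struct


def to_bf16(x: float) -> float:
    """Emulate rounding x to bfloat16 (round-to-nearest-even); NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values, round_fn) -> float:
    """Sum `values`, rounding the running total with `round_fn` after every add."""
    total = 0.0
    for v in values:
        total = round_fn(total + v)
    return total


values = [to_bf16(0.01)] * 10_000   # 10,000 small bf16 addends; true sum is about 100.1
print(accumulate(values, to_bf16))  # stalls far below 100
print(accumulate(values, to_fp32))  # close to 100.1
```

The inputs are identical in both runs. Only the precision of the running total differs, which is why mixed precision setups often feed bf16 values into an fp32 accumulator.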

Additionally, we trained on TPUs, which by default always do bf16 computation for matmuls under the hood (even if both inputs are fp32), so there's not too much of a speed difference between all fp32 and mixed precision (maybe ~10% slower with all fp32).
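Below is a rough, stdlib-only emulation of the TPU behaviour described above. It assumes that fp32 inputs are rounded to bf16 and that the products are accumulated in fp32, and it compares the resulting dot-product error with an all-fp32 one. It models numerics only, not speed, and the helper names are hypothetical. In JAX itself, the matmul precision can be raised with the `precision` argument (for example `jax.lax.Precision.HIGHEST`) or the `jax.default_matmul_precision` context manager.

```python
import random
import struct


def to_bf16(x: float) -> float:
    """Emulate rounding x to bfloat16 (round-to-nearest-even); NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot(a, b, bf16_inputs: bool) -> float:
    """Dot product with an fp32 accumulator; optionally round the inputs to bf16."""
    acc = 0.0
    for x, y in zip(a, b):
        if bf16_inputs:
            x, y = to_bf16(x), to_bf16(y)
        # A bf16 * bf16 product (8-bit significands) is exact in fp32 barring under/overflow.
        acc = to_fp32(acc + to_fp32(x * y))
    return acc


random.seed(0)
a = [to_fp32(random.uniform(-1, 1)) for _ in range(256)]
b = [to_fp32(random.uniform(-1, 1)) for _ in range(256)]
reference = sum(x * y for x, y in zip(a, b))   # float64 reference
scale = sum(abs(x * y) for x, y in zip(a, b))  # normalizer that avoids dividing by ~0

err_fp32 = abs(dot(a, b, bf16_inputs=False) - reference) / scale
err_bf16 = abs(dot(a, b, bf16_inputs=True) - reference) / scale
print(f"all-fp32 error: {err_fp32:.1e}, bf16-input error: {err_bf16:.1e}")
```

Because the accumulation stays in fp32, the bf16-input error stays bounded by roughly the bf16 rounding of each input. That is consistent with the point above that default TPU matmuls already trade some precision for speed even in an "all fp32" run.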

wilson1yan · Feb 21 '24 20:02