GaLore
`torch_run.py` lacking autocast and scaling for Automatic Mixed Precision
Hey,
As mentioned in the title, the script converts the model directly to BF16 instead of using the `torch.amp` autocast and gradient-scaling functions needed for AMP.
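For contrast with the direct conversion of the model to BF16, here is a minimal sketch of a BF16 autocast training step in PyTorch. It assumes a generic `model`, `optimizer`, and MSE loss; the function name `amp_train_step` is hypothetical and is not taken from `torch_run.py`. The weights and optimizer state stay in FP32, while eligible ops inside `torch.autocast` run in BF16. Loss scaling with `GradScaler` (`torch.amp.GradScaler` in recent PyTorch releases) would sit on top of this; it matters most for FP16, whose narrow exponent range lets small gradients underflow.

```python
import torch
import torch.nn as nn


def amp_train_step(model, optimizer, batch, target, device_type="cpu"):
    """One mixed-precision step: FP32 weights, BF16 compute inside autocast.

    Hypothetical helper for illustration only.
    """
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, eligible ops (e.g. the matmul in nn.Linear) run in BF16;
    # the FP32 weights are cast on the fly rather than converted in place.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        output = model(batch)
        # Compute the loss in FP32 explicitly for numerical safety.
        loss = nn.functional.mse_loss(output.float(), target)
    # Gradients flow back to, and are stored in, the FP32 parameters.
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(16, 4)  # parameters remain torch.float32
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    print(amp_train_step(model, opt, x, y))
```

On a GPU, move the model and data to CUDA and pass `device_type="cuda"`.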
This means the projected memory shown here counts only 2 bytes per parameter for the BF16 model weights, yet according to various sources the results after training this way would be poor. To train properly we would need AMP, which means 6 bytes per parameter (a 4-byte FP32 master copy plus a 2-byte BF16 copy). That blows the 24 GiB mentioned in the paper out of the water.
For Llama 3 8B, you would need 8 × 10^9 × 6 bytes = 48 × 10^9 bytes ≈ 44.7 GiB just to hold the parameters under BF16 AMP.
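The arithmetic above can be checked with a few lines of Python. This sketch covers weight storage only (no gradients, optimizer state, or activations) and assumes the rounded 8 × 10^9 parameter count and the 6-bytes-per-parameter AMP accounting used in this issue; the helper name `weight_memory_gib` is hypothetical.

```python
GIB = 2**30  # bytes per GiB


def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed just to hold the model weights, in GiB."""
    return num_params * bytes_per_param / GIB


LLAMA3_8B_PARAMS = 8e9  # rounded parameter count used in the estimate

# Pure BF16: a single 2-byte copy of the weights.
pure_bf16 = weight_memory_gib(LLAMA3_8B_PARAMS, 2)
# BF16 AMP: FP32 master weights (4 bytes) + BF16 working copy (2 bytes).
bf16_amp = weight_memory_gib(LLAMA3_8B_PARAMS, 4 + 2)

print(f"pure BF16: {pure_bf16:.1f} GiB")  # pure BF16: 14.9 GiB
print(f"BF16 AMP:  {bf16_amp:.1f} GiB")   # BF16 AMP:  44.7 GiB
```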
I just wanted to point this out and ask why it was done this way. The paper also mentions a 58 GiB minimum, but I think you would need much more than that.
If this is a deliberate decision, please point me to the studies showing that this kind of training is stable.
Source: https://docs.fast.ai/callback.fp16.html