FP8 training causes OOM
When FP8 is enabled, a model loaded in fp16 (LLaMA) OOMs during training, while the same model works perfectly in fp16 mode. My guess is that the conversion of the model to TE layers changes the dtype to float32.
In this code a new TransformerEngine Linear module is created, but the original weights are cloned and the new module has params_dtype=torch.float32. This causes the OOM when the model is converted. https://github.com/huggingface/accelerate/blob/dcde1e93d09abea02a8e7f4a07a2c5734b87b60e/src/accelerate/utils/transformer_engine.py#L31-L41
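To illustrate the problem (a minimal sketch, not the actual accelerate or Transformer Engine code; it assumes a CUDA device and that `te.Linear` still defaults its `params_dtype` to float32):

```python
# Sketch of the issue: TransformerEngine's Linear allocates its parameters in
# float32 unless params_dtype is passed, so converting an fp16 model creates a
# second, larger copy of every weight.
import torch
import transformer_engine.pytorch as te

fp16_linear = torch.nn.Linear(4096, 4096, dtype=torch.float16)
te_linear = te.Linear(4096, 4096)  # params_dtype not passed -> defaults to float32

print(fp16_linear.weight.dtype)  # torch.float16
print(te_linear.weight.dtype)    # torch.float32, i.e. twice the bytes per parameter
```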
Indeed, the linear layer needs to be created with the same dtype as the original one. Would you like to suggest a PR with a fix?
Honestly, I lack the knowledge to fix it myself, and my attempts still use the same amount of memory.
@sgugger I would like to work on this PR. I am new to open source contribution, so if you can guide me on how to work on this PR, I would like to take it up.
I've tried to remove the old parameters and avoid creating new tensors inside of Transformer Engine (skip_weight_param_allocation), but maybe some extra tensors are created inside Transformer Engine as part of FP8 training. I'm not sure whether that causes the OOM or whether I'm missing some tensors that should be cleaned up.
I was just planning to add this to the existing code because we just want to retain the datatype.
After this, LLaMA still OOMs during loading. The problem is probably deeper, and I'm not sure now what causes it.
Also the parameter is called params_dtype, not dtype.
Yes, you are right, the parameter is params_dtype (I just wanted to show the idea in the snippet). The OOM error might be due to other causes, as you said; the above change is only meant to ensure that the converted layer has the same dtype as the original. @sgugger kindly advise if I should create a PR for this change.
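The proposed change is roughly the following (an illustrative sketch with a made-up helper name, not the exact accelerate code or PR diff):

```python
# Create the TE layer with the original layer's dtype so no float32 parameters
# are allocated, then copy the existing weights over.
import torch
import transformer_engine.pytorch as te


def replace_with_te_linear(module: torch.nn.Linear) -> te.Linear:
    has_bias = module.bias is not None
    te_module = te.Linear(
        module.in_features,
        module.out_features,
        bias=has_bias,
        params_dtype=module.weight.dtype,  # keep fp16/bf16 instead of the float32 default
    )
    with torch.no_grad():
        te_module.weight.copy_(module.weight)
        if has_bias:
            te_module.bias.copy_(module.bias)
    return te_module
```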
Yes, you can definitely open a PR with this fix.
Added PR #1467.
Still, the issue is not resolved. It simply creates the tensors in fp16 and replaces the model's ones with them.
@overvalidated It is FP8 mixed-precision training. The actual memory usage will be higher than in regular training since you have two copies of the model: one in FP8 and one in FP32 (or a lower precision, but that wouldn't really work for training).
But how does fp16 training work then? Shouldn't it cause OOM as well?
@overvalidated I have no code I can use to reproduce this, so I can't really explain what goes wrong for you.
Hi all, we finally narrowed down two sources of memory leakage in the implementation that we could improve. #2089 will fix this, reducing your memory usage by a significant amount.
For example, loading "meta-llama/Llama-2-7b-hf" in FP8 used nearly 38 GB before; now it is only 12.61 GB (compared to 12.61 GB in bf16, the dtype the weights were originally loaded in).
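As a rough sanity check on those numbers (back-of-the-envelope only; the parameter count is approximate and buffers/activations are ignored):

```python
# Approximate weight memory for Llama-2-7B (~6.74B parameters).
params = 6.74e9
gib = 1024 ** 3

bf16_weights = params * 2 / gib        # ~12.6 GiB, matching the reported 12.61 GB
extra_fp32_copy = params * 4 / gib     # ~25 GiB of duplicated float32 parameters

print(round(bf16_weights, 2))                    # ~12.55
print(round(bf16_weights + extra_fp32_copy, 2))  # ~37.66, consistent with the ~38 GB before the fix
```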