
FP8 training causes OOM

Open overvalidated opened this issue 2 years ago • 14 comments

When FP8 is enabled, a model loaded in fp16 (LLaMA) OOMs during training. The same model trains perfectly in fp16 mode. My guess is that the conversion of the model to TE layers changes the dtype to float32.
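Roughly, the setup looks like this (a minimal sketch; the checkpoint name and loading flags are illustrative, not my exact script):

```python
import torch
from transformers import AutoModelForCausalLM
from accelerate import Accelerator

# FP8 mixed precision via accelerate (needs Transformer Engine and an FP8-capable GPU)
accelerator = Accelerator(mixed_precision="fp8")

# The model loads fine in fp16 on its own
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint; the issue just says "llama"
    torch_dtype=torch.float16,
)

# prepare() swaps nn.Linear layers for Transformer Engine layers;
# this is where memory blows up
model = accelerator.prepare(model)
```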

overvalidated avatar May 13 '23 16:05 overvalidated

In this code a new TransformerEngine Linear module is created and the original weights are cloned into it, but the new module has params_dtype=torch.float32. This causes OOM when the model is converted. https://github.com/huggingface/accelerate/blob/dcde1e93d09abea02a8e7f4a07a2c5734b87b60e/src/accelerate/utils/transformer_engine.py#L31-L41
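Paraphrasing the linked lines (a sketch, not the exact source): the te.Linear is built with Transformer Engine's default params_dtype (float32) and the fp16 weights are only cloned in afterwards, so fp32 parameter storage is allocated first.

```python
import torch.nn as nn
import transformer_engine.pytorch as te

def convert_linears(model: nn.Module):
    # Paraphrased sketch of accelerate's convert_model for Linear layers
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            has_bias = module.bias is not None
            # No params_dtype is passed, so TE allocates the new weight in float32
            te_module = te.Linear(module.in_features, module.out_features, bias=has_bias)
            te_module.weight.data = module.weight.data.clone()
            if has_bias:
                te_module.bias.data = module.bias.data.clone()
            setattr(model, name, te_module)
        else:
            convert_linears(module)
```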

overvalidated avatar May 13 '23 16:05 overvalidated

Indeed, the linear layer needs to be created with the same dtype as the original one. Would you like to suggest a PR with a fix?

sgugger avatar May 16 '23 14:05 sgugger

Honestly, I lack the knowledge to fix it myself; with what I've tried so far it still uses the same amount of memory.

overvalidated avatar May 16 '23 18:05 overvalidated

@sgugger I would like to work on this PR. I am new to open source contribution, so if you can guide me on how to approach it, I would like to take it up.

avisinghal6 avatar May 20 '23 17:05 avisinghal6

I've tried removing the old parameters and avoiding the creation of new tensors inside of Transformer Engine (skip_weight_param_allocation), but maybe some extra tensors are still created inside Transformer Engine as part of FP8 training. I'm not sure whether that causes the OOM or whether I'm missing some tensors that should be cleaned up.
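What I tried looks roughly like this (a sketch; skip_weight_param_allocation is a Transformer Engine constructor flag whose exact semantics may differ between versions, and with it set the weight is supplied at call time instead of being stored on the module):

```python
import torch.nn as nn
import transformer_engine.pytorch as te

def make_te_linear(module: nn.Linear) -> te.Linear:
    # `module` is the original fp16 nn.Linear being replaced
    return te.Linear(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        params_dtype=module.weight.dtype,   # keep the original dtype
        skip_weight_param_allocation=True,  # don't allocate a fresh weight tensor
    )

# FP8 scaling metadata is still allocated internally by Transformer Engine,
# which may be where the remaining memory goes.
```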

overvalidated avatar May 20 '23 18:05 overvalidated

I was just planning to add this to the existing code, because we just want to retain the datatype: [Screenshot 2023-05-20 at 11:58:45 AM]

avisinghal6 avatar May 20 '23 18:05 avisinghal6

After this, LLaMA still OOMs during loading. The problem is probably deeper, and I'm not sure now what causes it. Also, the parameter is called params_dtype, not dtype.

overvalidated avatar May 20 '23 20:05 overvalidated

Yes, you are right, the parameter is params_dtype (I just wanted to show the idea in the snippet). The OOM error might have other causes, as you said; the change above only ensures that the converted layer has the same dtype as the original one. @sgugger kindly advise if I should create a PR for this change.
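Concretely, the snippet in the screenshot amounts to something like this (a sketch, assuming the conversion helper in src/accelerate/utils/transformer_engine.py):

```python
import torch.nn as nn
import transformer_engine.pytorch as te

def replace_linear(module: nn.Linear) -> te.Linear:
    """Sketch of the proposed fix: build the TE layer in the original dtype."""
    te_module = te.Linear(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        params_dtype=module.weight.dtype,  # e.g. torch.float16 for an fp16 model
    )
    te_module.weight.data = module.weight.data.clone()
    if module.bias is not None:
        te_module.bias.data = module.bias.data.clone()
    return te_module
```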

avisinghal6 avatar May 20 '23 21:05 avisinghal6

Yes, you can definitely open a PR with this fix.

sgugger avatar May 22 '23 13:05 sgugger

Added PR #1467.

avisinghal6 avatar May 22 '23 16:05 avisinghal6

The issue is still not resolved. The fix simply creates the new tensors in fp16 and then replaces them with the model's original ones.

overvalidated avatar May 22 '23 17:05 overvalidated

@overvalidated This is FP8 mixed-precision training. The actual memory usage will be higher than in regular training since you have two copies of the model: one in FP8 and one in FP32 (or a lower precision, but that wouldn't really work for training).
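As a rough back-of-the-envelope for a 7B-parameter model (weights only, ignoring activations, gradients, optimizer state and FP8 scaling metadata):

```python
params = 7e9  # assumed 7B-parameter LLaMA

fp16_weights = params * 2               # plain fp16 copy of the weights
fp8_mixed    = params * 4 + params * 1  # fp32 master weights + fp8 compute copy

print(f"fp16 weights:        {fp16_weights / 1e9:.0f} GB")  # ~14 GB
print(f"fp8 mixed precision: {fp8_mixed / 1e9:.0f} GB")     # ~35 GB
```

That is roughly in line with the ~38 GB figure reported for the 7B model later in this thread.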

sgugger avatar May 22 '23 17:05 sgugger

But how does fp16 training work then? Shouldn't it cause OOM as well?

overvalidated avatar May 22 '23 19:05 overvalidated

@overvalidated I have no code I can use to reproduce this, so I can't really explain what goes wrong for you.

sgugger avatar May 22 '23 19:05 sgugger

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 16 '23 15:06 github-actions[bot]

Hi all, we finally narrowed down two sources of memory leakage in the implementation that we could improve. #2089 will fix this, reducing memory usage by a significant amount.

For example, loading "meta-llama/Llama-2-7b-hf" in FP8 used nearly 38 GB before; now it is only 12.61 GB (compared to 12.61 GB in bf16, the dtype the weights were originally loaded in).
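If you want to check the numbers locally, something along these lines should do it (a sketch; exact figures depend on your accelerate and Transformer Engine versions):

```python
import torch
from transformers import AutoModelForCausalLM
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp8")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
model = accelerator.prepare(model)

# Peak GPU memory after conversion to Transformer Engine layers
print(f"{torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```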

muellerzr avatar Oct 26 '23 20:10 muellerzr