In this code a new TransformerEngine `Linear` module is created, but the original weights are cloned and the new module is constructed with `params_dtype=torch.float32`. This causes an OOM when the model is converted. https://github.com/huggingface/accelerate/blob/dcde1e93d09abea02a8e7f4a07a2c5734b87b60e/src/accelerate/utils/transformer_engine.py#L31-L41
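For illustration, the problematic pattern looks roughly like the sketch below (a simplification of the linked code, not a verbatim copy; `convert_linear` is a name I made up):

```python
import torch
import torch.nn as nn
import transformer_engine.pytorch as te

def convert_linear(module: nn.Linear) -> te.Linear:
    # The new TE module allocates its own parameters in fp32
    # (params_dtype defaults to torch.float32), so at this point the
    # original fp16 weight and a fresh fp32 weight coexist in memory.
    te_module = te.Linear(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        params_dtype=torch.float32,
    )
    # Cloning adds yet another copy of the weight before the
    # TE-allocated tensor can be garbage-collected.
    te_module.weight.data = module.weight.data.clone()
    if module.bias is not None:
        te_module.bias.data = module.bias.data.clone()
    return te_module
```

For a 7B-parameter model held in fp16 (~14 GB), the transient fp32 copies alone add roughly 28 GB, which is enough to OOM most single GPUs.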
Honestly, I lack the knowledge to fix it myself, and my attempt still uses the same amount of memory.
I've tried removing the old parameters and avoiding the creation of new tensors inside Transformer Engine (`skip_weight_param_allocation`), but maybe some extra tensors are still created inside Transformer Engine as...
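The attempt looked roughly like this sketch (it assumes TE's `skip_weight_param_allocation` constructor flag skips both weight and bias allocation, and that reusing the existing parameter this way is supported; that assumption is exactly what I'm unsure about):

```python
import torch.nn as nn
import transformer_engine.pytorch as te

def convert_linear_inplace(module: nn.Linear) -> te.Linear:
    # Ask TE not to allocate its own weight tensor at all...
    te_module = te.Linear(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        skip_weight_param_allocation=True,
    )
    # ...and hand it the existing parameters instead of clones,
    # so no new storage should be created.
    te_module.register_parameter("weight", module.weight)
    if module.bias is not None:
        te_module.register_parameter("bias", module.bias)
    return te_module
```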
After this, LLaMA still OOMs during loading. The problem is probably deeper, and I'm not sure now what causes it. Also, the parameter is called `params_dtype`, not `dtype`.
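For reference, continuing the sketch above, constructing with the model's own dtype would look something like this (still a sketch, not a verified fix):

```python
# Note the keyword is `params_dtype`, not `dtype`; matching the
# model's dtype at least avoids the silent upcast to fp32.
te_module = te.Linear(
    module.in_features,
    module.out_features,
    bias=module.bias is not None,
    params_dtype=module.weight.dtype,  # e.g. torch.float16
)
```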
The issue is still not resolved. The conversion simply creates tensors in fp16 and then replaces them with the model's own ones.
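One way to confirm the transient duplication is to watch the CUDA peak allocation around the conversion. The snippet below is a minimal sketch: it uses a small stack of `nn.Linear` layers as a stand-in for the real model (loading LLaMA would go there instead) and accelerate's `convert_model` from the linked file:

```python
import torch
import torch.nn as nn
from accelerate.utils import convert_model  # accelerate's TE converter

# A small fp16 stand-in for the real model.
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(4)]).half().cuda()

torch.cuda.reset_peak_memory_stats()
convert_model(model)
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak during conversion: {peak_gib:.2f} GiB")
# If the swap were truly in-place, the peak would stay near the fp16
# parameter size; a much higher peak confirms the duplicate tensors.
```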
But how does fp16 training work then? Shouldn't it cause an OOM as well?