
Memory usage of DPO trainer seems to grow stepwise over time

Open · Emerald01 opened this issue 1 year ago · 6 comments

Hi,

I am DPO-training a checkpoint of Mixtral-8x7B-Instruct from a previous supervised fine-tune.

I mainly followed this script https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py with 8 H100 GPUs, flash attention, and DeepSpeed ZeRO-2. Everything looks good, but I notice that memory consumption grows stepwise.
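For reference, a minimal ZeRO-2 config without CPU offloading (an illustrative sketch, not my exact config) can be passed to TrainingArguments via the deepspeed argument roughly like this:

    # Sketch of a DeepSpeed ZeRO-2 config without CPU offloading; values are illustrative.
    # It can be given to transformers.TrainingArguments(deepspeed=...) as a dict,
    # or saved as a JSON file and passed by path.
    ds_zero2 = {
        "bf16": {"enabled": "auto"},
        "zero_optimization": {
            "stage": 2,
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }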

Any ideas why it is not constant? It looks like there may be a memory garbage-collection issue.

Screenshot (2024-02-27): GPU memory usage growing stepwise during training

Emerald01 · Feb 28 '24

cc @kashif have you ever experienced this?

younesbelkada · Feb 29 '24

@Emerald01 could you share your ZeRO-2 config? Do you use CPU offloading? I have the same problem: it goes out of memory after some steps with Mixtral. My env: 8 A100 GPUs.

saeedkhaki92 · Mar 04 '24

Hi! Can you try clearing the CUDA cache between training steps? You could modify the DPOTrainer source code to override the training_step() method (https://github.com/huggingface/transformers/blob/e9476832942a19cf99354776ef112babc83c139a/src/transformers/trainer.py#L2848) and call torch.cuda.empty_cache() together with gc.collect() after each step.
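Something along these lines, as a minimal sketch (the subclass name is just illustrative; subclassing instead of editing the source also works):

    # Sketch: override training_step in a DPOTrainer subclass and free cached CUDA memory each step.
    import gc
    from typing import Any, Dict, Union

    import torch
    import torch.nn as nn
    from trl import DPOTrainer


    class EmptyCacheDPOTrainer(DPOTrainer):
        def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
            loss = super().training_step(model, inputs)
            # Release cached CUDA blocks and collect unreachable Python objects after every step.
            torch.cuda.empty_cache()
            gc.collect()
            return loss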

younesbelkada · Mar 05 '24

@younesbelkada that works!

Emerald01 · Mar 06 '24

oh nice! cc @muellerz do you know if Trainer properly handles torch.cuda.empty_cache() after each training step? Perhaps worth making a PR on the transformers side? Let me know if you want me to have a look as well.

younesbelkada · Mar 11 '24

@younesbelkada I believe transformers does not clear the cache after each training step. Following your suggestion, I added the cache clearing and gc collection; compared to the previous stepwise-growing memory, usage is now almost constant:

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        # Run the regular training step, then free cached CUDA memory and run garbage collection.
        loss_step = super().training_step(model, inputs)
        torch.cuda.empty_cache()
        gc.collect()
        return loss_step
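This override lives in a small DPOTrainer subclass (like the EmptyCacheDPOTrainer sketch above) that is constructed in place of DPOTrainer; everything else in the script is unchanged. Roughly (a sketch following the dpo_llama2.py example; exact arguments depend on your setup and trl version):

    # Sketch: use the subclass as a drop-in replacement for DPOTrainer.
    # Argument names follow the dpo_llama2.py example and are illustrative.
    dpo_trainer = EmptyCacheDPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )
    dpo_trainer.train()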

The following is the current GPU memory usage running the same script:

Screenshot (2024-03-06): GPU memory roughly constant across training steps

Emerald01 · Mar 11 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] · Apr 05 '24