
support 14B full training?

Open · shapera-lab opened this issue 10 months ago • 14 comments

When training the 14B-parameter model (DiT), I run into "CUDA Out Of Memory" (OOM) errors.

shapera-lab avatar Mar 06 '25 09:03 shapera-lab

@zsp1993 We are working on it.

Artiprocher avatar Mar 10 '25 03:03 Artiprocher

@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 GPUs (8×80 GB VRAM).
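
For example, a full launch might look roughly like the sketch below. Only the two flags on the last lines are the ones added here; the script name and every other argument are placeholders for whatever your existing training command already uses, so adjust them to your setup.

# Sketch only: script name and all arguments except the last two flags are placeholders.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_wan_t2v.py \
    --task train \
    --dataset_path data/my_dataset \
    --output_path ./lightning_logs \
    --use_gradient_checkpointing_offload \
    --training_strategy "deepspeed_stage_3"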

Artiprocher avatar Mar 10 '25 10:03 Artiprocher

Thanks for your work! I can now fully fine-tune the 14B model.

By the way, how should we load the resulting model with diffusers? I cannot convert the DeepSpeed checkpoint into a format that diffusers can load.

For example, the DeepSpeed checkpoint looks like this:

user@Trainer:/mnt/raid0/wan/lightning_logs/version_6/checkpoints/epoch=0-step=63.ckpt/checkpoint$ ls
bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt  bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt  zero_pp_rank_0_mp_rank_00_model_states.pt  zero_pp_rank_4_mp_rank_00_model_states.pt
bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt  bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt  zero_pp_rank_1_mp_rank_00_model_states.pt  zero_pp_rank_5_mp_rank_00_model_states.pt
bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt  bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt  zero_pp_rank_2_mp_rank_00_model_states.pt  zero_pp_rank_6_mp_rank_00_model_states.pt
bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt  bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt  zero_pp_rank_3_mp_rank_00_model_states.pt  zero_pp_rank_7_mp_rank_00_model_states.pt

alfredplpl avatar Mar 11 '25 08:03 alfredplpl

@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 GPUs (8×80 GB VRAM).

Still OOM when training the I2V 14B Wan2.1 model.

ucaswindlike avatar Mar 11 '25 11:03 ucaswindlike

@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 GPUs (8×80 GB VRAM).

Still OOM when training the I2V 14B Wan2.1 model.

Hello, did you modify the I2V training script yourself?

sihaowei-yw avatar Mar 12 '25 03:03 sihaowei-yw

@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 GPUs (8×80 GB VRAM).

Still OOM when training the 14B I2V model.

lith0613 avatar Mar 12 '25 04:03 lith0613

@alfredplpl You can find a .py script in the checkpoint folder; it is provided by the PyTorch Lightning framework. Please run it.
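
For example, assuming the script is zero_to_fp32.py and sits next to the checkpoint folder shown above, running it looks roughly like this (the output filename is only an example; check the script's --help for your version):

# Sketch: consolidate the ZeRO-3 shards into a single fp32 state dict.
cd /mnt/raid0/wan/lightning_logs/version_6/checkpoints/epoch=0-step=63.ckpt
python zero_to_fp32.py . pytorch_model.bin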

Artiprocher avatar Mar 14 '25 02:03 Artiprocher

@ucaswindlike @sihaowei-yw @lith0613 To be honest, the current solution of fine-tuning the full 14B T2V model is already at its limit. The 14B I2V model requires slightly more GPU memory, which we currently cannot achieve. Our tensor parallelism framework is still under development, and we will continue to optimize this feature.

Artiprocher avatar Mar 14 '25 02:03 Artiprocher

@ucaswindlike @sihaowei-yw @lith0613 To be honest, the current solution of fine-tuning the full 14B T2V model is already at its limit. The 14B I2V model requires slightly more GPU memory, which we currently cannot achieve. Our tensor parallelism framework is still under development, and we will continue to optimize this feature.

Thanks for your code. I find that after a 2000-step fine-tune, the result is all noise. Can you give me some suggestions?

renrenzsbbb avatar Mar 16 '25 15:03 renrenzsbbb

@ucaswindlike @sihaowei-yw @lith0613 To be honest, the current solution of fine-tuning the full 14B T2V model is already at its limit. The 14B I2V model requires slightly more GPU memory, which we currently cannot achieve. Our tensor parallelism framework is still under development, and we will continue to optimize this feature.

Are 8 NVIDIA 4090s enough?

mobilejammer avatar Mar 21 '25 03:03 mobilejammer

@mobilejammer No, 8× A100 GPUs are required.

Artiprocher avatar Mar 24 '25 03:03 Artiprocher

I ran into the error "torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass." when I set --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" for Wan 14B fine-tuning. Could you kindly provide some hints on how to solve it? Or will you work on full or LoRA fine-tuning of Wan 14B with 720P videos?

TumCCC avatar Mar 25 '25 07:03 TumCCC

mark

animemory avatar Apr 08 '25 06:04 animemory

@ucaswindlike @sihaowei-yw @lith0613 To be honest, the current solution of fine-tuning the full 14B T2V model is already at its limit. The 14B I2V model requires slightly more GPU memory, which we currently cannot achieve. Our tensor parallelism framework is still under development, and we will continue to optimize this feature.

Thanks for your code. I find that after a 2000-step fine-tune, the result is all noise. Can you give me some suggestions?

I also have this problem; after fine-tuning, the results are very bad.

huangjch526 avatar May 02 '25 15:05 huangjch526