Support full training of the 14B model?
When training the 14B-parameter model (DiT), I encounter "CUDA Out of Memory" (OOM) errors.
@zsp1993 We are working on it.
@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 (8*80GB VRAM) GPUs.
Thanks for your work! I can now fully fine-tune the 14B model.
By the way, how should we load the model with diffusers? I cannot convert the DeepSpeed checkpoint into a diffusers-compatible model.
For example, the DeepSpeed checkpoint contains the following:
user@Trainer:/mnt/raid0/wan/lightning_logs/version_6/checkpoints/epoch=0-step=63.ckpt/checkpoint$ ls
bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt zero_pp_rank_0_mp_rank_00_model_states.pt zero_pp_rank_4_mp_rank_00_model_states.pt
bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt zero_pp_rank_1_mp_rank_00_model_states.pt zero_pp_rank_5_mp_rank_00_model_states.pt
bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt zero_pp_rank_2_mp_rank_00_model_states.pt zero_pp_rank_6_mp_rank_00_model_states.pt
bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt zero_pp_rank_3_mp_rank_00_model_states.pt zero_pp_rank_7_mp_rank_00_model_states.pt
@zsp1993 Done. Now you can use --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" to train the 14B T2V model on 8 A100 (8*80GB VRAM) GPUs.
Still OOM when training the I2V 14B Wan2.1 model.
Hello, did you modify the I2V training script yourself?
Still OOM when training the 14B I2V model.
@alfredplpl You can find a .py script in the checkpoint folder. It is provided by the PyTorch Lightning framework. Please run it.
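For reference, here is a minimal sketch of consolidating the sharded ZeRO Stage 3 checkpoint into a single state dict, assuming the shards were written by DeepSpeed (as the zero_pp_rank_* file names above suggest) and that the deepspeed package is installed; the key names may still need remapping before the weights can be loaded into a diffusers model:

```python
# Minimal sketch: merge sharded ZeRO-3 checkpoint files into one fp32 state dict.
# The checkpoint path below is taken from the `ls` output above (the .ckpt
# directory, i.e. the parent of the `checkpoint/` folder); adjust for your run.
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

ckpt_dir = "/mnt/raid0/wan/lightning_logs/version_6/checkpoints/epoch=0-step=63.ckpt"

# Gathers the partitioned parameters from all zero_pp_rank_*_model_states.pt shards.
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)

# Save one consolidated file; converting key names to the diffusers layout
# (if they differ) is a separate step not shown here.
torch.save(state_dict, "consolidated_dit.pt")
```

Equivalently, the zero_to_fp32.py script that DeepSpeed drops next to the checkpoint can be run directly, which is presumably the script referred to above.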
@ucaswindlike @sihaowei-yw @lith0613 To be honest, the current solution of fine-tuning the full 14B T2V model is already at its limit. The 14B I2V model requires slightly more GPU memory, which we currently cannot achieve. Our tensor parallelism framework is still under development, and we will continue to optimize this feature.
Thanks for your code. I find that after 2,000 steps of fine-tuning the result is all noise. Can you give me some suggestions?
Are 8 NVIDIA 4090s enough?
@mobilejammer No. 8*A100 GPUs are required.
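A rough back-of-envelope estimate (my own numbers, not from the maintainers) of why 24 GB cards fall short even when ZeRO Stage 3 shards the weights, gradients, and Adam states across 8 GPUs:

```python
# Back-of-envelope memory estimate for full fine-tuning of a 14B-parameter model
# with bf16 weights/gradients and fp32 Adam states (mixed precision), sharded
# across 8 GPUs by ZeRO Stage 3. Activations and temporary buffers are ignored.
params = 14e9
weights_bf16 = params * 2   # 28 GB
grads_bf16   = params * 2   # 28 GB
master_fp32  = params * 4   # 56 GB  (fp32 copy of the weights)
adam_states  = params * 8   # 112 GB (fp32 momentum + variance)

total_gb = (weights_bf16 + grads_bf16 + master_fp32 + adam_states) / 1e9
per_gpu_gb = total_gb / 8
print(f"total ~{total_gb:.0f} GB, ~{per_gpu_gb:.0f} GB per GPU before activations")
# total ~224 GB, ~28 GB per GPU -- already above a 24 GB RTX 4090, and the
# activations, text encoder, and VAE still need room on top of that.
```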
I hit the error "torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass." when I set --use_gradient_checkpointing_offload and --training_strategy "deepspeed_stage_3" for Wan 14B fine-tuning. Could you kindly provide some hints to solve it? Or do you plan to work on full fine-tuning or LoRA training of Wan 14B with 720P videos?
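For what it's worth, that error is raised by PyTorch's non-reentrant activation checkpointing when the recomputed tensors' metadata does not match what was saved in the forward pass. A commonly suggested mitigation (not confirmed by the maintainers for this repo, and the function below is only a hypothetical stand-in for wherever the DiT blocks are checkpointed) is to call torch.utils.checkpoint.checkpoint with the metadata check disabled:

```python
# Hypothetical sketch of relaxing the recomputation metadata check that triggers
# "Recomputed values ... have different metadata than during the forward pass".
# `block` and `hidden_states` stand in for the DiT block and its input; whether
# this is safe for Wan 14B fine-tuning has not been verified.
import torch.utils.checkpoint as cp

def checkpointed_forward(block, hidden_states, *args):
    return cp.checkpoint(
        block,
        hidden_states,
        *args,
        use_reentrant=False,       # required for determinism_check to take effect
        determinism_check="none",  # skip the metadata comparison on recompute
    )
```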
mark
I also have this problem; after fine-tuning, the result is very bad.