[Bug] OOM when fine-tuning WAN 2.1-14B on 8x H200 (batch=1, V1 script)
Describe the bug
Using the V1 script to fine-tune WAN 2.1-14B on 8×H200 GPUs with a batch size of 1 and input data of shape 512×512×81 results in an out-of-memory (OOM) error. Is this expected? Additionally, does the current version support FP16 training?
Reproduction
.
Environment
.
For comparison, the hunyuanvideo-13B finetune script in V0 (e.g., finetune_hunyuan.sh) seems to support input shapes like 1280×720×125 by default. Moreover, I have successfully trained with the shape 512×512×121 and bf16 master weight type.
It should work. Could you share the script you are using? Did you turn on SP?
NUM_GPUS=8
torchrun --nnodes 1 --nproc_per_node $NUM_GPUS \
    fastvideo/v1/training/wan_training_pipeline.py \
    --model_path /root/highspeedstorage/Data01/weights/Wan2.1-T2V-14B-Diffusers \
    --inference_mode False \
    --pretrained_model_name_or_path /root/highspeedstorage/Data01/weights/Wan2.1-T2V-14B-Diffusers \
    --data_path "$DATA_DIR" \
    --validation_dataset_file "$VALIDATION_DATASET_FILE" \
    --train_batch_size 1 \
    --sp_size 1 \
    --tp_size 1 \
    --num_gpus $NUM_GPUS \
    --hsdp_replicate_dim $NUM_GPUS \
    --hsdp_shard_dim 1 \
    --train_sp_batch_size 1 \
    --dataloader_num_workers 4 \
    --gradient_accumulation_steps 1 \
    --max_train_steps 30000 \
    --learning_rate 1e-5 \
    --mixed_precision "bf16" \
    --checkpointing_steps 6000 \
    --validation_steps 100 \
    --validation_sampling_steps "50" \
    --log_validation \
    --checkpoints_total_limit 30 \
    --ema_start_step 0 \
    --training_cfg_rate 0.0 \
    --output_dir "/root/highspeedstorage/Data01/FastVideo/output/outputs/wan_14B_finetune" \
    --tracker_project_name wan_finetune \
    --num_height 512 \
    --num_width 512 \
    --num_frames 81 \
    --validation_guidance_scale "1.0" \
    --num_euler_timesteps 50 \
    --weight_decay 0.01 \
    --not_apply_cfg_solver \
    --dit_precision "fp32" \
    --max_grad_norm 1.0 \
    --enable_gradient_checkpointing_type "full" \
    --allow_tf32
It doesn't work for me either. I ran 480×832×97 with sp_size=8 on an 8×H100 node. My script is as follows:
#!/bin/bash
export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=online
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
height=480
width=832
num_frames=129
num_latent_t=$(( num_frames / 4 + 1 ))
version=v3.0
dataset_name=${version}_${height}x${width}x${num_frames}
MODEL_PATH="/datasets/video_generation/Wan2.1/Wan2.1-I2V-14B-480P-Diffusers"
DATA_DIR="/datasets/fastvideo_data_process/${dataset_name}"
VALIDATION_DIR="data/crush-smol_processed_i2v/validation_parquet_dataset/"
OUTPUT_DIR="/users/yjhong/fastvideo"
NUM_GPUS=8
# export CUDA_VISIBLE_DEVICES=4,5
# IP=[MASTER NODE IP]
# Training arguments
training_args=(
--tracker_project_name "wan_i2v_finetune"
--output_dir "$OUTPUT_DIR/outputs/wan_i2v_finetune_${dataset_name}"
--max_train_steps 2000
--train_batch_size 1
--train_sp_batch_size 1
--gradient_accumulation_steps 1
--num_latent_t 24
--num_height ${height}
--num_width ${width}
--num_frames ${num_frames}
)
# Parallel arguments
parallel_args=(
--num_gpus $NUM_GPUS
--sp_size 8
--tp_size 1
--hsdp_replicate_dim 1
--hsdp_shard_dim 8
)
# Model arguments
model_args=(
--model_path $MODEL_PATH
--pretrained_model_name_or_path $MODEL_PATH
)
# Dataset arguments
dataset_args=(
--data_path "$DATA_DIR"
--dataloader_num_workers 10
)
# Validation arguments
validation_args=(
--log_validation
--validation_preprocessed_path "$VALIDATION_DIR"
--validation_steps 100
--validation_sampling_steps "40"
--validation_guidance_scale "1.0"
)
# Optimizer arguments
optimizer_args=(
--learning_rate 1e-5
--mixed_precision "bf16"
--checkpointing_steps 1000
--weight_decay 1e-4
--max_grad_norm 1.0
)
# Miscellaneous arguments
miscellaneous_args=(
--inference_mode False
--allow_tf32
--checkpoints_total_limit 3
--training_cfg_rate 0.1
--multi_phased_distill_schedule "4000-1"
--not_apply_cfg_solver
--dit_precision "fp32"
--num_euler_timesteps 50
--ema_start_step 0
)
# If you do not have 32 GPUs and need to fit in memory, you can: 1) increase sp_size, or 2) reduce num_latent_t (see the sketch after the torchrun command below)
torchrun \
--nnodes 1 \
--nproc_per_node $NUM_GPUS \
fastvideo/v1/training/wan_i2v_training_pipeline.py \
"${parallel_args[@]}" \
"${model_args[@]}" \
"${dataset_args[@]}" \
"${training_args[@]}" \
"${optimizer_args[@]}" \
"${validation_args[@]}" \
"${miscellaneous_args[@]}"
Could you try adding --enable_gradient_checkpointing_type "full"? (You may need to pull the latest code first to make sure it supports gradient checkpointing.)
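For example, appended to the miscellaneous_args array in the script above (a minimal sketch; it assumes your checkout already parses this flag):

# enable full activation checkpointing to reduce peak memory during backward
miscellaneous_args+=(
  --enable_gradient_checkpointing_type "full"
)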
Yes, it worked. I was missing that argument.
Could you try adding --enable_gradient_checkpointing_type "full"? (You may need to pull the latest code first to make sure it supports gradient checkpointing.)
Thanks for your reply! But --enable_gradient_checkpointing_type "full" is already in my script :(
@yiboz2001 For the training script, we use --num_latent_t instead of --num_frames to control how long each sample in the training data is. This means that if you want to train with 81 frames, you need to use --num_latent_t 21. In addition, you'll want to use sp_size > 1 to enable sequence parallelism for 14B models if you still have OOM issues.
Sorry for the confusion, we will make sure to clearly document this.
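For reference, this conversion can be written as a one-liner, matching the num_latent_t calculation in the second script above (it assumes the same 4× temporal compression, which is what makes 81 frames correspond to 21 latent frames):

# pixel-space frame count -> latent frame count under 4x temporal compression
num_frames=81
num_latent_t=$(( num_frames / 4 + 1 ))   # 81 / 4 = 20 (integer division), + 1 = 21
echo "--num_latent_t ${num_latent_t}"    # prints: --num_latent_t 21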
Have you solved the OOM problem?