DiffSynth-Studio

Operation timeout when using Wan2.2-Fun-A14B-Control with multiple GPUs

Open · FYRichie opened this issue 3 months ago · 2 comments

Hi,

I am grateful for this amazing repository for video-conditioned Wan2.2 models. I am trying to use Wan2.2-Fun-A14B-Control with a depth video as the control input. Running on a single GPU (an A6000 with 48 GB of VRAM) leads to OOM, so I switched to four A6000s. I encountered two timeouts:

  1. When downloading the weights: solved by downloading the weights manually and passing the local path to ModelConfig via path.
  2. When running the pipeline: video encoding works normally (a VAE_encoding progress bar appears and runs through 24 steps), but denoising gets stuck at the first of its 50 steps. GPU utilization is 100%, yet VRAM usage is extremely low (around 7-8 GB per GPU).
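Regarding the first timeout: before pointing ModelConfig at manually downloaded weights, it can help to confirm the files are really on disk, because origin_file_pattern is a glob (diffusion_pytorch_model*.safetensors) that may match more than one shard. The following is a minimal standard-library sketch; resolve_local_files is a hypothetical helper, not part of DiffSynth-Studio.

```python
import glob
import os


def resolve_local_files(local_dir: str, file_pattern: str) -> list[str]:
    """Return the sorted files under local_dir that match a glob-style pattern.

    Hypothetical helper: it mirrors the origin_file_pattern strings used in the
    ModelConfig entries, but only inspects files that already exist locally.
    """
    matches = sorted(glob.glob(os.path.join(local_dir, file_pattern)))
    if not matches:
        raise FileNotFoundError(
            f"no files matching {file_pattern!r} under {local_dir!r}; "
            "was the manual download complete?"
        )
    return matches


# Example (paths taken from the script in this issue):
# resolve_local_files("models/PAI/Wan2.2-Fun-A14B-Control",
#                     "high_noise_model/diffusion_pytorch_model*.safetensors")
```

If the pattern matches several shards, pointing path at a single file would load only part of the weights; check how your DiffSynth-Studio version expects multiple files to be passed rather than assuming.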

The command I used to run my code is:

```shell
CUDA_VISIBLE_DEVICES=1,2,3,4 torchrun --nproc_per_node=4 generate_video_wan.py
```

The code I ran is as follows:

```python
import os

import torch
import torch.distributed as dist
from PIL import Image

from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# Standard Wan negative prompt, kept verbatim in Chinese. English gloss:
# garish colors, overexposed, static, blurry details, subtitles, style, artwork,
# painting, frame, still, overall grayish, worst quality, low quality, JPEG
# compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
# poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers,
# motionless frames, cluttered background, three legs, many people in the
# background, walking backwards
NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

# SCENE and INSTANCE_ID were not defined in the snippet as posted;
# these values are placeholders.
SCENE = "scene"
INSTANCE_ID = 0

if __name__ == "__main__":
    # use_usp=True enables sequence parallelism across the torchrun ranks.
    # The weights were downloaded manually, so each ModelConfig points at a local path.
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        use_usp=True,
        model_configs=[
            ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu", path="models/PAI/Wan2.2-Fun-A14B-Control/high_noise_model/diffusion_pytorch_model.safetensors"),
            ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu", path="models/PAI/Wan2.2-Fun-A14B-Control/low_noise_model/diffusion_pytorch_model.safetensors"),
            ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", path="models/PAI/Wan2.2-Fun-A14B-Control/models_t5_umt5-xxl-enc-bf16.pth"),
            ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", path="models/PAI/Wan2.2-Fun-A14B-Control/Wan2.1_VAE.pth"),
        ],
        tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*", path="models/Wan-AI/Wan2.1-T2V-1.3B/google/umt5-xxl"),
    )
    pipe.enable_vram_management()

    reference_image = Image.open("myimage.png")  # 768 x 1024
    control_video = VideoData("myvideo.mp4", height=768, width=1024)  # 121 frames (depth video)
    prompt = "My prompt"

    video = pipe(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        reference_image=reference_image,
        control_video=control_video,
        height=768,
        width=1024,
        num_frames=121,
        seed=1,
        tiled=True,
    )

    # Every rank returns the video; only rank 0 writes it to disk.
    if dist.get_rank() == 0:
        out_dir = os.path.join(SCENE, "video")
        os.makedirs(out_dir, exist_ok=True)
        save_video(video, os.path.join(out_dir, f"{INSTANCE_ID:03d}_wan.mp4"), fps=24, quality=5)
```
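For a sense of the workload in the call above, the number of tokens the diffusion transformer processes per denoising step can be estimated from the video shape. This is a rough sketch under assumptions not stated in the thread: the Wan2.1 VAE compresses 4x in time (keeping the first frame) and 8x in space, the transformer uses a 1x2x2 patch size, and sequence parallelism splits the tokens evenly across ranks. Extra tokens from the reference image and the second pass for the negative prompt are ignored.

```python
def dit_tokens_per_step(num_frames: int, height: int, width: int,
                        t_stride: int = 4, s_stride: int = 8,
                        patch: tuple[int, int, int] = (1, 2, 2)) -> int:
    """Estimate the transformer sequence length for one video (assumed factors)."""
    # Assumed causal VAE: first frame kept, then every t_stride frames -> 1 latent frame.
    if (num_frames - 1) % t_stride != 0:
        raise ValueError(f"num_frames should be {t_stride}k+1, got {num_frames}")
    latent_t = (num_frames - 1) // t_stride + 1
    latent_h, latent_w = height // s_stride, width // s_stride
    pt, ph, pw = patch
    return (latent_t // pt) * (latent_h // ph) * (latent_w // pw)


tokens = dit_tokens_per_step(121, 768, 1024)
print(tokens)       # 95232  (31 latent frames x 48 x 64 patches)
print(tokens // 4)  # 23808  tokens per GPU if split evenly over 4 ranks
```

Under these assumptions, 121 frames is a valid 4k+1 length and the token count divides evenly by 4, so the sequence split itself should not leave ranks with mismatched shapes.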

I have looked through the existing issues in this repo, but this seems to be a new problem. Thank you for your time; I look forward to your response.

FYRichie, Oct 28 '25 21:10

@FYRichie You can remove offload_device="cpu". This setting reduces the VRAM required, but it can make the program slow.
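A rough, weights-only estimate shows why offloading is tempting on a 48 GB card. It assumes each of the A14B model's two experts has about 14B parameters stored in bf16 (2 bytes per parameter); activations, the T5 text encoder, and the VAE come on top. It also assumes that, with sequence parallelism, each rank holds its own full copy of the weights, so adding GPUs does not by itself shrink this number.

```python
def weights_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed only to hold the weights, in GiB (no activations or buffers)."""
    return num_params * bytes_per_param / 1024**3


expert = weights_gib(14e9)  # one ~14B-parameter expert in bf16 (assumed size)
print(f"one expert:  {expert:.1f} GiB")      # 26.1 GiB
print(f"two experts: {2 * expert:.1f} GiB")  # 52.2 GiB
```

Both experts together already exceed a single 48 GB card, while one expert fits with some headroom, which is consistent with the trade-off described above: offloading saves VRAM at the cost of moving weights between CPU and GPU.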

Additionally, multi-GPU inference needs NVLink; without it, inference will be slow. I don't know whether your GPUs support NVLink, so please check.
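One way to check is the connection matrix printed by nvidia-smi topo -m, where "NV#" marks GPU pairs joined by NVLink and codes such as PIX, PHB, NODE, or SYS mark PCIe or CPU-interconnect paths. The sketch below parses that matrix; gpu_links and nvlink_pairs are hypothetical helpers, and the parser assumes the usual layout (GPU columns first, one row per GPU).

```python
import re
import shutil
import subprocess

ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")  # strip terminal formatting codes, if any
GPU_RE = re.compile(r"^GPU\d+$")


def gpu_links(topo_text: str) -> dict[tuple[str, str], str]:
    """Parse `nvidia-smi topo -m` output into {(gpu_a, gpu_b): link_code}."""
    rows = [ln.split() for ln in ANSI_RE.sub("", topo_text).splitlines() if ln.strip()]
    gpus = [tok for tok in rows[0] if GPU_RE.match(tok)]  # header: GPU columns come first
    links = {}
    for row in rows[1:]:
        if not GPU_RE.match(row[0]):
            continue  # NIC rows, legend lines, etc.
        for peer, code in zip(gpus, row[1:1 + len(gpus)]):
            if peer != row[0]:  # skip the "X" (self) diagonal
                links[(row[0], peer)] = code
    return links


def nvlink_pairs(links: dict[tuple[str, str], str]) -> list[tuple[str, str]]:
    """GPU pairs connected directly over NVLink (link code starts with "NV")."""
    return sorted(pair for pair, code in links.items() if code.startswith("NV"))


if __name__ == "__main__" and shutil.which("nvidia-smi"):
    out = subprocess.run(["nvidia-smi", "topo", "-m"],
                         capture_output=True, text=True, check=True).stdout
    print("NVLink pairs:", nvlink_pairs(gpu_links(out)) or "none")
```

Because the command in this issue uses CUDA_VISIBLE_DEVICES=1,2,3,4, the relevant rows would be GPU1 through GPU4, assuming CUDA and nvidia-smi number the GPUs the same way (not guaranteed unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set).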

Artiprocher, Oct 30 '25 06:10

Hi @Artiprocher, thanks for your reply! My system hangs (rather than just running slowly) when I run the code above. However, I found that upgrading torch from older versions such as 2.5.0 or 2.6.0 to 2.9.0 solves the problem. Why is this related to the torch version?
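The thread does not establish why 2.9.0 helps, so the most that can be done here is to make the observed dependency explicit. Below is a minimal guard sketch; parse_version and warn_if_older are hypothetical helpers, and 2.9.0 is only the version the reporter saw working, not a documented minimum.

```python
import re
import warnings


def parse_version(version: str) -> tuple[int, int, int]:
    """'2.6.0+cu124' -> (2, 6, 0). Drops the local '+...' tag and suffixes like 'a0'."""
    release = version.split("+", 1)[0]
    nums = []
    for part in release.split(".")[:3]:
        m = re.match(r"\d+", part)
        nums.append(int(m.group()) if m else 0)
    nums += [0] * (3 - len(nums))  # pad "2.9" to (2, 9, 0)
    return tuple(nums)


def warn_if_older(installed: str, reported_ok: str = "2.9.0") -> bool:
    """Warn and return True if `installed` is older than `reported_ok`."""
    older = parse_version(installed) < parse_version(reported_ok)
    if older:
        warnings.warn(f"torch {installed} is older than {reported_ok}; "
                      "older versions were reported to hang in this multi-GPU setup")
    return older
```

Called as warn_if_older(torch.__version__) near the top of the script, it surfaces the version mismatch before the pipeline reaches the denoising loop instead of after a silent hang.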

FYRichie, Nov 03 '25 17:11