
[From Single File] support `from_single_file` method for `WanAnimateTransformer3DModel`

Open samadwar opened this pull request 1 month ago • 9 comments

What does this PR do?

Added support for loading checkpoints from a single file. Some modifications to the `convert_wan_transformer_to_diffusers` method were required for it to work with `WanAnimateTransformer3DModel`.
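
For reference, a minimal usage sketch of the intended API (the checkpoint URL is one of the QuantStack GGUF files referenced in this thread, and the exact keyword arguments may vary):

import torch

from diffusers import GGUFQuantizationConfig, WanAnimateTransformer3DModel

# Load the Wan Animate transformer directly from a single GGUF checkpoint,
# taking the model config from the Diffusers repo.
transformer = WanAnimateTransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="Wan-AI/Wan2.2-Animate-14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)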

Best regards, Sam

Fixes # (issue)

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

samadwar avatar Nov 21 '25 00:11 samadwar

Hi @samadwar, do you have a single-file version of Wan Animate we can use to test this PR?

DN6 avatar Nov 21 '25 07:11 DN6

Hi @samadwar, do you have a single-file version of Wan Animate we can use to test this PR?

Yes, https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf or any GGUF file in https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF

or any file from here https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/tree/main/Wan22Animate

samadwar avatar Nov 21 '25 07:11 samadwar

Hi @samadwar, thanks for the PR! Would you be able to share an example of a code snippet which uses WanAnimateTransformer3DModel.from_single_file? I tried to test the PR using the following script:

import os

import torch

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

single_file_ckpt = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf"
# single_file_ckpt = "https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/blob/main/Wan22Animate/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"

device = "cuda:0"
dtype = torch.bfloat16
seed = 42

transformer_kwargs = {}
_, single_file_ext = os.path.splitext(single_file_ckpt)
if single_file_ext == ".gguf":
    quantization_config = GGUFQuantizationConfig(compute_dtype=dtype)
    transformer_kwargs["quantization_config"] = quantization_config

transformer = WanAnimateTransformer3DModel.from_single_file(
    single_file_ckpt,
    config=model_id,
    subfolder="transformer",
    **transformer_kwargs,
)

pipe = WanAnimatePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to(device)

image = load_image("/path/to/reference_image.png")
pose_video = load_video("/path/to/pose_video.mp4")
face_video = load_video("/path/to/face_video.mp4")

video = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt="People in the video are doing actions.",
    height=720,
    width=1280,
    mode="animate",
    guidance_scale=1.0,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
).frames[0]

export_to_video(video, "wan_animate_single_file.mp4", fps=30)

Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF doesn't produce any errors, but the generated samples seem to be just noise:

https://github.com/user-attachments/assets/026294e4-9f23-4ec2-a8f6-c9cb8eb4ae9b

If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I get an OOM error on an A100 (80 GB VRAM) and a lot of keys in the model don't seem to be used (they mainly end in .scale_weight, so they might be the FP8 scaling parameters?).
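
As a quick diagnostic, a minimal sketch along these lines could confirm how many keys carry those scales (the local path below is a placeholder for the downloaded checkpoint):

from safetensors import safe_open

# Placeholder path to the locally downloaded fp8-scaled checkpoint.
ckpt_path = "/path/to/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"

# List all tensor names and pick out the ones that look like FP8 scale parameters.
with safe_open(ckpt_path, framework="pt") as f:
    keys = list(f.keys())

scale_keys = [k for k in keys if k.endswith(".scale_weight")]
print(f"{len(scale_keys)} of {len(keys)} keys end in .scale_weight")
print(scale_keys[:5])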

dg845 avatar Nov 22 '25 07:11 dg845


Yeah, I am experiencing the same issue today. I had it working before; I will check and get back to you.

For the GGUF I am using an AWS ml.g6e.4xlarge instance that comes with 45 GB of VRAM, and I don't have access to more GPU VRAM to test fp8. But I guess one way to check is to load the file with the safetensors package and check whether the actual values of the weights match or not.
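
Something along these lines is what I mean; a rough sketch that compares the converted single-file weights against the Diffusers repo weights (the safetensors path is a placeholder, and it assumes a non-quantized checkpoint so the values are directly comparable):

import torch

from diffusers import WanAnimateTransformer3DModel

model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
single_file_ckpt = "/path/to/single_file_checkpoint.safetensors"  # placeholder local path

# Load the transformer once via the new single-file path and once from the
# Diffusers repo, then compare the converted weights tensor by tensor.
sf_model = WanAnimateTransformer3DModel.from_single_file(
    single_file_ckpt, config=model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
ref_model = WanAnimateTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

sf_state, ref_state = sf_model.state_dict(), ref_model.state_dict()
for name, ref_tensor in ref_state.items():
    sf_tensor = sf_state.get(name)
    if sf_tensor is None:
        print(f"missing from single file: {name}")
    elif not torch.allclose(sf_tensor.float(), ref_tensor.float(), atol=1e-3):
        print(f"mismatch: {name}")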

samadwar avatar Nov 22 '25 09:11 samadwar

@dg845 I fixed the issue, can you try now?

samadwar avatar Nov 22 '25 10:11 samadwar

Code I am using:

import torch
import numpy as np

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

LoRA = True
device_gpu = torch.device("cuda")
original_model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
lora_model_id = "Kijai/WanVideo_comfy"
lora_model_path = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"

print("Loading transformer ....")
transformer = WanAnimateTransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config=original_model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    offload_device="cpu",
    device=device_gpu
)
print("Transformer loaded successfully ....")

print("Loading pipeline ....")
pipe = WanAnimatePipeline.from_pretrained(
    original_model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

if LoRA:
    pipe.load_lora_weights(
        lora_model_id,
        weight_name=lora_model_path,
        adapter_name="lightning",
        offload_device="cpu",
        device=device_gpu
    )

pipe.enable_model_cpu_offload()
print("Pipeline loaded successfully ....")

# Load the character image
image = load_image(
    "Wan2.2/examples/wan_animate/animate/image.jpeg"
)

# Load pose and face videos (preprocessed from reference video)
# Note: Videos should be preprocessed to extract pose keypoints and face features
# Refer to the Wan-Animate preprocessing documentation for details
pose_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_pose.mp4")
face_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_face.mp4")

# Calculate optimal dimensions based on VAE constraints
max_area = 1280 * 720
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "People in the video are doing actions."

# Animation mode: Animate the character with the motion from pose/face videos
print("Generating animation ....")
output = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt=prompt,
    #  negative_prompt=negative_prompt,
    height=height,
    width=width,
    segment_frame_length=77,
    guidance_scale=1.0,
    prev_segment_conditioning_frames=1,  # refert_num in original code
    num_inference_steps=4 if LoRA else 20,  # 4 steps with the Lightning LoRA, 20 without
    mode="animate",
).frames[0]
print("Exporting animation ....")
export_to_video(output, "output_animation__.mp4", fps=30)

https://github.com/user-attachments/assets/b2833bb7-fc89-4061-b7cb-ec431b964b04

samadwar avatar Nov 22 '25 11:11 samadwar

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bot /style

DN6 avatar Dec 07 '25 14:12 DN6

Style bot fixed some files and pushed the changes.

github-actions[bot] avatar Dec 07 '25 14:12 github-actions[bot]