
why drop last 4 frames in wan-animate training?

Open parryppp opened this issue 1 month ago • 3 comments

https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/wan_video_new.py#L1073


parryppp avatar Nov 03 '25 04:11 parryppp

@parryppp In the original implementation of this model, the number of frames in the input video is not consistent with the output. We have to drop the last 4 frames to align the two videos; otherwise, we get a "tensor shape mismatch" error.
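For reference, a minimal sketch of the shape arithmetic (illustrative only, not taken from the pipeline code), assuming Wan's 4x temporal VAE compression so that T pixel frames map to (T - 1) // 4 + 1 latent frames:

# Shape arithmetic behind the mismatch (illustrative sketch, not pipeline code;
# assumes Wan's 4x temporal VAE compression).
def latent_len(num_frames: int) -> int:
    return (num_frames - 1) // 4 + 1

noise_len = latent_len(81)              # 21 latent frames, sized from the video only
y_len = 1 + latent_len(81)              # 22: reference-image latent + 81-frame video latents
assert y_len != noise_len               # -> tensor shape mismatch without the drop

y_len_after_drop = 1 + latent_len(77)   # 1 + 20 = 21, matches once the last 4 frames are dropped
assert y_len_after_drop == noise_len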

Artiprocher avatar Nov 04 '25 09:11 Artiprocher

@Artiprocher Thanks first for your excellent open-source work on DiffSynth-Studio; it's a fantastic project. After analyzing the Wan-Animate training flow, I believe the "drop the last 4 frames" operation in WanVideoPostUnit_AnimateVideoSplit is only a downstream patch that avoids the dimension-mismatch error; it does not resolve the root cause. The true source of the error lies in the upstream units WanVideoUnit_NoiseInitializer and WanVideoUnit_InputVideoEmbedder, which neglect to account for the input_image. As a result, while WanVideoPostUnit_AnimateVideoSplit avoids program crashes, it introduces a much more serious training issue: the model is conditioned on the first 77 frames, yet asked to predict all 81 frames of the video (assuming an 81-frame input video).

Specifically:

  1. In DiffSynth-Studio, for Wan-Animate, the time dimension of the input latents is determined in WanVideoUnit_NoiseInitializer. It initializes a noise tensor whose time length is length = (num_frames - 1) // 4 + 1, where num_frames refers to the input video frame count. When num_frames = 81, this gives length = 21. This computation only considers input_video, ignoring the input_image. As a result, the baseline time dimension is wrong. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L593-L606

  2. Because the noise time dimension is fixed to 21, every subsequent tensor that interacts with it must match that time dimension. In the training flow, the input x is constructed by concatenating the latents (noise) and y along the channel dimension. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L1483-L1485 - For y, it is a concatenation in the time dimension of ref_pixel_values (from the input_image, time dim = 1) and y_reft (from animate_inpaint_video). Hence y_reft must have time dimension 21 - 1 = 20, so the source animate_inpaint_video must be truncated to a frame count L such that (L - 1) // 4 + 1 = 20, which for L = 77 means dropping the last 4 of the 81 frames. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L1213-L1235

  3. For pose_latents, which is added into x[:, :, 1:]: since x has time = 21, x[:, :, 1:] has time dimension 20, so pose_latents must also have time = 20, again forcing truncation of the last 4 frames of the animate pose video (see the shape sketch after this list). https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/models/wan_video_animate_adapter.py#L623-L625
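To make these constraints concrete, here is a shape-only sketch of the concatenations above (hypothetical shapes: batch 1, 16 latent channels, a 60x104 latent grid; it only mirrors the time-dimension bookkeeping, not the real channel layout of y):

import torch

# Hypothetical shapes for illustration only.
z, h, w = 16, 60, 104
latents    = torch.randn(1, z, 21, h, w)   # noise from WanVideoUnit_NoiseInitializer (time = 21)
ref_latent = torch.randn(1, z, 1,  h, w)   # from input_image (time = 1)
y_reft     = torch.randn(1, z, 20, h, w)   # from the truncated 77-frame animate_inpaint_video

y = torch.cat([ref_latent, y_reft], dim=2) # time concat: 1 + 20 = 21, must equal latents' time
x = torch.cat([latents, y], dim=1)         # channel concat: only works if the time dims agree

pose_latents = torch.randn(1, z, 20, h, w) # added into x[:, :, 1:], so time must be 21 - 1 = 20
assert y.shape[2] == latents.shape[2]
assert pose_latents.shape[2] == x.shape[2] - 1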

The serious consequence: the conditioning tensor y only includes information from the first 77 frames (both inpaint and pose), whereas the target input_latents still encodes all 81 frames. Thus there is a severe mismatch between the conditioning information and the training target in the time dimension. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L617-L633
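Put in frame counts (illustrative arithmetic only):

target_frames       = 81          # encoded into input_latents -> 21 latent frames
conditioning_frames = 81 - 4      # inpaint + pose after the drop -> 20 latent frames (+ 1 ref latent)
unconditioned_tail  = target_frames - conditioning_frames
assert unconditioned_tail == 4    # the last 4 frames are trained on without matching conditioning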

Possible Fixes:

  1. Modify WanVideoUnit_NoiseInitializer to change the way the length of the noise tensor is calculated (a sanity-check sketch of the resulting lengths follows this list):
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "animate_pose_video"))
    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, animate_pose_video):
        length = (num_frames - 1) // 4 + 1
        ...
        ...
        if animate_pose_video is not None:  # or a more robust check for the animate branch
            length += 1  # reserve one latent frame for the input_image reference

        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
        return {"noise": noise}
  2. Update WanVideoUnit_InputVideoEmbedder so that the input_image (reference frame) is concatenated into input_latents:
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "animate_pose_video", "input_image"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, animate_pose_video, input_image):
        if input_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(["vae"])
        input_video = pipe.preprocess_video(input_video)
        input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        ...
        ...
        if input_image is not None and animate_pose_video is not None:
            # Encode the reference image and prepend its latent along the time
            # dimension, so input_latents also covers the reference frame.
            input_image = pipe.preprocess_video([input_image])
            input_image_latents = pipe.vae.encode(input_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents = torch.concat([input_image_latents, input_latents], dim=2)
        
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents}
  3. Finally, remove WanVideoPostUnit_AnimateVideoSplit from the pipeline's unit list:
class WanVideoPipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        ...
        self.units = [
            ...
            WanVideoUnit_NoiseInitializer(), # generate noise
            ...
            WanVideoUnit_InputVideoEmbedder(), 
            ...
            # WanVideoPostUnit_AnimateVideoSplit(),
            WanVideoPostUnit_AnimatePoseLatents(),
            WanVideoPostUnit_AnimateFacePixelValues(),
            WanVideoPostUnit_AnimateInpaint(),
            ...
        ]
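As a hedged sanity check of the lengths after fixes 1 and 2 (assuming an 81-frame input video, one reference image, and full-length inpaint/pose videos; plain arithmetic, not pipeline code):

def latent_len(num_frames: int) -> int:
    return (num_frames - 1) // 4 + 1

noise_len        = latent_len(81) + 1   # fix 1: 21 video latents + 1 reserved for the reference
input_latent_len = 1 + latent_len(81)   # fix 2: reference latent prepended to the 81-frame latents
y_len            = 1 + latent_len(81)   # ref (1) + y_reft from the full 81-frame inpaint video
pose_len         = latent_len(81)       # pose latents go into x[:, :, 1:]

assert noise_len == input_latent_len == y_len == 22
assert pose_len == noise_len - 1        # 21, matching x[:, :, 1:] without dropping frames

With these lengths in agreement, WanVideoPostUnit_AnimateVideoSplit is no longer needed, and the conditioning and the training target should both cover all 81 frames.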

zwplus avatar Nov 18 '25 09:11 zwplus

@zwplus thank you for your reply. I'm a bit confused: after this change, does it still end up with a latent of the 81-frame input video being matched against a latent of the 77-frame background video?

parryppp avatar Nov 26 '25 08:11 parryppp

@zwplus Hi! I’m also working with Wan Animate. I’m wondering if you’ve successfully tuned the model. I tried tuning but noticed significant color shifting in the first frame after several iterations (no color shift at the training start). Thank you!

gugite avatar Dec 11 '25 17:12 gugite

@gugite We faced the same issue initially. However, after addressing the problems mentioned above, our training seems to be proceeding normally.

zwplus avatar Dec 12 '25 14:12 zwplus

> @zwplus thank you for your reply. I'm a bit confused: after this change, does it still end up with a latent of the 81-frame input video being matched against a latent of the 77-frame background video?

With this fix, we no longer drop the last 4 frames. Both the input video and the background/pose videos now correspond to the full 81 frames.

zwplus avatar Dec 12 '25 14:12 zwplus