Why drop the last 4 frames in Wan-Animate training?
https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/wan_video_new.py#L1073
@parryppp In the original implementation of this model, the number of frames in the input video is not consistent with the output. We have to drop the last 4 frames to align the two videos; otherwise, we will see a "tensor shape mismatch" error.
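For context, the alignment this truncation enforces can be sketched with the temporal compression formula discussed below (a standalone illustration assuming an 81-frame clip, not code from the repository):

```python
# The VAE maps N frames to (N - 1) // 4 + 1 latent frames, and the conditioning
# path prepends 1 reference-image latent frame. With 81 frames the two paths
# disagree (22 vs. 21); dropping the last 4 frames (81 -> 77) makes them match.
assert (81 - 1) // 4 + 1 == 21           # noise / target latent length
assert 1 + (81 - 1) // 4 + 1 == 22       # conditioning length without the drop
assert 1 + (77 - 1) // 4 + 1 == 21       # conditioning length after the drop
```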
@Artiprocher Thanks first for your excellent open-source work on DiffSynth-Studio — it’s a fantastic project.
After analyzing the Wan-Animate training flow, I believe the "drop the last 4 frames" operation in `WanVideoPostUnit_AnimateVideoSplit` is only a downstream patch that avoids the dimension mismatch error; it does not resolve the root cause of the mismatch. The true source of the error lies in the upstream units `WanVideoUnit_NoiseInitializer` and `WanVideoUnit_InputVideoEmbedder`, which neglect to account for the `input_image`. As a result, while `WanVideoPostUnit_AnimateVideoSplit` avoids the crash, it introduces a much more serious training issue: the model is conditioned on only the first 77 frames, yet asked to predict all 81 frames of the video (assuming an input video of 81 frames).
Specifically:

- In DiffSynth-Studio, for Wan-Animate, the time dimension of the input latents is determined in `WanVideoUnit_NoiseInitializer`. It initializes a noise tensor whose time length is `length = (num_frames - 1) // 4 + 1`, where `num_frames` refers to the input video frame count. When `num_frames = 81`, this gives `length = 21`. This computation only considers `input_video` and ignores `input_image`, so the baseline time dimension is wrong. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L593-L606
- Because the noise time dimension is fixed to 21, all subsequent tensors interacting with it must match that time dimension. In the training flow, the input `x` is constructed by concatenating `latents` (the noise) and `y` along the channel dimension. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L1483-L1485
- `y` is a concatenation along the time dimension of `ref_pixel_values` (from `input_image`, time dim = 1) and `y_reft` (from `animate_inpaint_video`). Hence `y_reft` must have time dimension 21 - 1 = 20, so the source `animate_inpaint_video` must be truncated to a frame count L such that (L - 1) // 4 + 1 = 20, which for L = 77 means dropping the last 4 of the 81 frames. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L1213-L1235
- `pose_latents` is added onto `x[:, :, 1:]`. Since `x` has time = 21, `x[:, :, 1:]` has time dimension 20, so `pose_latents` must also have time = 20, again forcing truncation of the last 4 frames of the `animate_pose_video`. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/models/wan_video_animate_adapter.py#L623-L625
The serious consequence: the conditioning tensor `y` only includes information from the first 77 frames (for both inpaint and pose), whereas the target `input_latents` still encodes all 81 frames. There is therefore a severe mismatch along the time dimension between the conditioning information and the training target. https://github.com/modelscope/DiffSynth-Studio/blob/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/diffsynth/pipelines/wan_video_new.py#L617-L633
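To make the arithmetic concrete, here is a quick sanity check (a standalone sketch using the formula above, not code from the repository; the helper name `latent_len` is purely illustrative):

```python
# Wan's VAE temporal compression maps N frames to (N - 1) // 4 + 1 latent frames.
def latent_len(num_frames: int) -> int:
    return (num_frames - 1) // 4 + 1

noise_t = latent_len(81)        # 21, computed from input_video only
ref_t = 1                       # input_image contributes 1 latent frame to y
y_reft_t = noise_t - ref_t      # 20, forced by the channel-concat with the noise

# Getting 20 latent frames requires cutting the inpaint/pose videos to 77 frames:
assert latent_len(77) == y_reft_t

# Meanwhile the training target still encodes all 81 frames:
assert latent_len(81) == 21     # conditioning covers 77 frames, target covers 81
```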
Possible Fixes:
- Modify `WanVideoUnit_NoiseInitializer` so that the length of the noise tensor accounts for the extra reference frame:
```python
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "animate_pose_video"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, animate_pose_video):
        length = (num_frames - 1) // 4 + 1
        ...
        ...
        if animate_pose_video is not None:  # or a more robust way to detect Animate mode
            length += 1  # reserve one latent frame for the input_image reference
        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
        return {"noise": noise}
```
- Update `WanVideoUnit_InputVideoEmbedder` so that the `input_image` (reference frame) is concatenated into `input_latents`:
```python
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "animate_pose_video", "input_image"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, animate_pose_video, input_image):
        if input_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(["vae"])
        input_video = pipe.preprocess_video(input_video)
        input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        ...
        ...
        if input_image is not None and animate_pose_video is not None:
            # Prepend the reference-image latent so input_latents matches the time length of the noise.
            input_image = pipe.preprocess_video([input_image])
            input_image_latents = pipe.vae.encode(input_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents = torch.concat([input_image_latents, input_latents], dim=2)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents}
```
- Finally, remove `WanVideoPostUnit_AnimateVideoSplit` from the pipeline units:
```python
class WanVideoPipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        ...
        self.units = [
            ...
            WanVideoUnit_NoiseInitializer(),  # generate noise
            ...
            WanVideoUnit_InputVideoEmbedder(),
            ...
            # WanVideoPostUnit_AnimateVideoSplit(),
            WanVideoPostUnit_AnimatePoseLatents(),
            WanVideoPostUnit_AnimateFacePixelValues(),
            WanVideoPostUnit_AnimateInpaint(),
            ...
        ]
```
@zwplus Thank you for your reply. I'm a bit confused: after this change, does it still end up with the latent of an 81-frame input video being matched against the latent of a 77-frame background video?
@zwplus Hi! I’m also working with Wan Animate. I’m wondering if you’ve successfully tuned the model. I tried tuning but noticed significant color shifting in the first frame after several iterations (no color shift at the training start). Thank you!
@gugite We faced the same issue initially. However, after addressing the problems mentioned above, our training seems to be proceeding normally.
> @zwplus Thank you for your reply. I'm a bit confused: after this change, does it still end up with the latent of an 81-frame input video being matched against the latent of a 77-frame background video?
With this fix, we no longer drop the last 4 frames. Both the input video and the background/pose videos now correspond to the full 81 frames.