generative-models icon indicating copy to clipboard operation
generative-models copied to clipboard

Question about VAE. Getting Help!!!

Open liujf69 opened this issue 7 months ago • 0 comments

When encoding and decoding with VAE, the output is distorted. Is this normal? Or is there a problem with my code?

import torch
import imageio
import numpy as np
from einops import rearrange
from decord import VideoReader
from torchvision import transforms
from diffusers.models import AutoencoderKLTemporalDecoder

# load video
def load_video(video_path: str, len_frames: int = 16): 
    video_reader = VideoReader(video_path)
    video_length = len(video_reader)
    frame_idx = np.linspace(0, video_length - 1, len_frames, dtype = int)
    video = torch.from_numpy(video_reader.get_batch(frame_idx).asnumpy()).permute(0, 3, 1, 2).contiguous()
    video = video / 255.
    return video # T C H W

if __name__ == "__main__":
    device = "cuda"
    # pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid-xt"
    pretrained_model_path = "./models/stable-video-diffusion-img2vid-xt" # local path
    vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_model_path, subfolder = "vae")
    vae.to(device)

    video_path = "./test.mp4"
    video_tensor = load_video(video_path = video_path) # T C H W
    video_tensor = video_tensor.unsqueeze(0).to(device) # B T C H W

    with torch.no_grad():
        # encode
        pixel_values = rearrange(video_tensor, "b f c h w -> (b f) c h w") # b t c h w -> bt c h w
        latents = vae.encode(pixel_values).latent_dist
        latents = latents.sample()
        latents = 0.18215 * latents

        # decode
        latents = 1 / 0.18215 * latents
        frames = vae.decode(latents, num_frames = video_tensor.shape[1]).sample
        # frames = frames.float()

    save_path = "./test.gif"
    save_frames = []
    for idx in range(frames.shape[0]): # T
        save_frames.append(transforms.ToPILImage()(frames[idx].cpu()))
    imageio.mimsave(save_path, save_frames, loop = 0)
    print("All Done!")

origin_video: image output_gif: image

liujf69 avatar Jul 19 '24 02:07 liujf69