
After VAE decode, result is NaN in SDXL mode.

Open ApolloRay opened this issue 1 year ago • 14 comments

I tried StreamDiffusion with an SDXL model, but after VAE decode the result image is NaN. How can I solve it? I added added_cond_kwargs in unet_step. Or when will StreamDiffusion support SDXL models?

ApolloRay avatar Jan 23 '24 11:01 ApolloRay

I have solved this problem, and StreamDiffusion can support SDXL models. The fix is to change the VAE dtype to fp32; otherwise the result overflows. OHOHOHOHOH!!!!!
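A minimal sketch of that dtype change, assuming a standard diffusers SDXL pipeline already loaded as pipe (the fp16 SDXL VAE is known to overflow at decode time):

```python
import torch

# Keep the rest of the pipeline in fp16, but run the VAE decode in fp32 so it
# does not overflow and return NaNs.
pipe.vae = pipe.vae.to(torch.float32)

def decode_fp32(vae, latents):
    latents = latents.to(torch.float32) / vae.config.scaling_factor
    return vae.decode(latents, return_dict=False)[0]
```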

ApolloRay avatar Jan 24 '24 11:01 ApolloRay

Thank you very much. Please feel free to submit a PR.

teftef6220 avatar Jan 28 '24 15:01 teftef6220

@ApolloRay Nice, can you share your method with us and show us the txt2img speed? Thanks.

bobby20180331 avatar Feb 06 '24 12:02 bobby20180331

@ApolloRay Nice, can you share your method with us and show us the txt2img speed? Thanks.

I will refine my code and release it soon.

ApolloRay avatar Feb 07 '24 02:02 ApolloRay

Any progress on this? I'm trying to load SDXL by tampering with the code, but I've never worked with diffusers before. We'd probably need to replace the StableDiffusionPipeline calls with StableDiffusionXLPipeline calls, and probably a bunch of other things?

Or don't we?

Any pointers would be appreciated at this point.

menguzat avatar Feb 12 '24 14:02 menguzat

@ApolloRay hi, did you succeed in making SDXL work with StreamDiffusion? How's the performance?

Humanoidme avatar Feb 19 '24 22:02 Humanoidme

```python
import time
from typing import List, Optional, Union, Any, Dict, Tuple, Literal

import numpy as np
import PIL.Image
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)

from streamdiffusion.image_filter import SimilarImageFilter

class StreamDiffusion:
def __init__(
    self,
    pipe: StableDiffusionXLPipeline,
    t_index_list: List[int],
    torch_dtype: torch.dtype = torch.float16,
    width: int = 1024,
    height: int = 1024,
    do_add_noise: bool = True,
    use_denoising_batch: bool = True,
    frame_buffer_size: int = 1,
    cfg_type: Literal["none", "full", "self", "initialize"] = "self",
) -> None:
    self.device = pipe.device
    self.dtype = torch_dtype
    self.generator = None

    self.height = height
    self.width = width

    self.latent_height = int(height // pipe.vae_scale_factor)
    self.latent_width = int(width // pipe.vae_scale_factor)

    self.frame_bff_size = frame_buffer_size
    self.denoising_steps_num = len(t_index_list)

    self.cfg_type = cfg_type

    if use_denoising_batch:
        self.batch_size = self.denoising_steps_num * frame_buffer_size
        if self.cfg_type == "initialize":
            self.trt_unet_batch_size = (
                self.denoising_steps_num + 1
            ) * self.frame_bff_size
        elif self.cfg_type == "full":
            self.trt_unet_batch_size = (
                2 * self.denoising_steps_num * self.frame_bff_size
            )
        else:
            self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
    else:
        self.trt_unet_batch_size = self.frame_bff_size
        self.batch_size = frame_buffer_size

    self.t_list = t_index_list

    self.do_add_noise = do_add_noise
    self.use_denoising_batch = use_denoising_batch

    self.similar_image_filter = False
    self.similar_filter = SimilarImageFilter()
    self.prev_image_result = None

    self.pipe = pipe
    self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

    self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
    self.text_encoder = pipe.text_encoder
    self.unet = pipe.unet
    self.vae = pipe.vae
    
    self.inference_time_ema = 0

def load_lcm_lora(
    self,
    pretrained_model_name_or_path_or_dict: Union[
        str, Dict[str, torch.Tensor]
    ] = "latent-consistency/lcm-lora-sdv1-5",
    adapter_name: Optional[Any] = None,
    **kwargs,
) -> None:
    self.pipe.load_lora_weights(
        pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
    )

def load_lora(
    self,
    pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name: Optional[Any] = None,
    **kwargs,
) -> None:
    self.pipe.load_lora_weights(
        pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
    )

def fuse_lora(
    self,
    fuse_unet: bool = True,
    fuse_text_encoder: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
) -> None:
    self.pipe.fuse_lora(
        fuse_unet=fuse_unet,
        fuse_text_encoder=fuse_text_encoder,
        lora_scale=lora_scale,
        safe_fusing=safe_fusing,
    )

def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
    self.similar_image_filter = True
    self.similar_filter.set_threshold(threshold)
    self.similar_filter.set_max_skip_frame(max_skip_frame)

def disable_similar_image_filter(self) -> None:
    self.similar_image_filter = False

@torch.no_grad()
def prepare(
    self,
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 50,
    guidance_scale: float = 1.2,
    delta: float = 1.0,
    generator: Optional[torch.Generator] = torch.Generator(),
    seed: int = 2,
) -> None:
    self.generator = generator
    self.generator.manual_seed(seed)
    # initialize x_t_latent (it can be any random tensor)
    if self.denoising_steps_num > 1:
        self.x_t_latent_buffer = torch.zeros(
            (
                (self.denoising_steps_num - 1) * self.frame_bff_size,
                4,
                self.latent_height,
                self.latent_width,
            ),
            dtype=self.dtype,
            device=self.device,
        )
    else:
        self.x_t_latent_buffer = None

    if self.cfg_type == "none":
        self.guidance_scale = 1.0
    else:
        self.guidance_scale = guidance_scale
    self.delta = delta

    do_classifier_free_guidance = False
    if self.guidance_scale > 1.0:
        do_classifier_free_guidance = True

    encoder_output = self.pipe.encode_prompt(
        prompt=prompt,
        device=self.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
    )
    self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    # ADD
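    # SDXL conditioning: keep the pooled text embeddings and build the
    # size/crop time ids that the SDXL UNet expects via added_cond_kwargs.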
    self.add_text_embeds = encoder_output[2]
    original_size = (self.height, self.width)
    crops_coords_top_left = (0, 0)
    target_size = (self.height, self.width)
    text_encoder_projection_dim = int(self.add_text_embeds.shape[-1])
    self.add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=encoder_output[0].dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )

    if self.use_denoising_batch and self.cfg_type == "full":
        uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
    elif self.cfg_type == "initialize":
        uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

    if self.guidance_scale > 1.0 and (
        self.cfg_type == "initialize" or self.cfg_type == "full"
    ):
        self.prompt_embeds = torch.cat(
            [uncond_prompt_embeds, self.prompt_embeds], dim=0
        )

    self.scheduler.set_timesteps(num_inference_steps, self.device)
    self.timesteps = self.scheduler.timesteps.to(self.device)

    # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
    self.sub_timesteps = []
    for t in self.t_list:
        self.sub_timesteps.append(self.timesteps[t])

    sub_timesteps_tensor = torch.tensor(
        self.sub_timesteps, dtype=torch.long, device=self.device
    )
    self.sub_timesteps_tensor = torch.repeat_interleave(
        sub_timesteps_tensor,
        repeats=self.frame_bff_size if self.use_denoising_batch else 1,
        dim=0,
    )

    self.init_noise = torch.randn(
        (self.batch_size, 4, self.latent_height, self.latent_width),
        generator=generator,
    ).to(device=self.device, dtype=self.dtype)

    self.stock_noise = torch.zeros_like(self.init_noise)

    c_skip_list = []
    c_out_list = []
    for timestep in self.sub_timesteps:
        c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
            timestep
        )
        c_skip_list.append(c_skip)
        c_out_list.append(c_out)

    self.c_skip = (
        torch.stack(c_skip_list)
        .view(len(self.t_list), 1, 1, 1)
        .to(dtype=self.dtype, device=self.device)
    )
    self.c_out = (
        torch.stack(c_out_list)
        .view(len(self.t_list), 1, 1, 1)
        .to(dtype=self.dtype, device=self.device)
    )

    alpha_prod_t_sqrt_list = []
    beta_prod_t_sqrt_list = []
    for timestep in self.sub_timesteps:
        alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
        beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
        alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
        beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
    alpha_prod_t_sqrt = (
        torch.stack(alpha_prod_t_sqrt_list)
        .view(len(self.t_list), 1, 1, 1)
        .to(dtype=self.dtype, device=self.device)
    )
    beta_prod_t_sqrt = (
        torch.stack(beta_prod_t_sqrt_list)
        .view(len(self.t_list), 1, 1, 1)
        .to(dtype=self.dtype, device=self.device)
    )
    self.alpha_prod_t_sqrt = torch.repeat_interleave(
        alpha_prod_t_sqrt,
        repeats=self.frame_bff_size if self.use_denoising_batch else 1,
        dim=0,
    )
    self.beta_prod_t_sqrt = torch.repeat_interleave(
        beta_prod_t_sqrt,
        repeats=self.frame_bff_size if self.use_denoising_batch else 1,
        dim=0,
    )

@torch.no_grad()
def update_prompt(self, prompt: str) -> None:
    encoder_output = self.pipe.encode_prompt(
        prompt=prompt,
        device=self.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )
    self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    t_index: int,
) -> torch.Tensor:
    noisy_samples = (
        self.alpha_prod_t_sqrt[t_index] * original_samples
        + self.beta_prod_t_sqrt[t_index] * noise
    )
    return noisy_samples

def scheduler_step_batch(
    self,
    model_pred_batch: torch.Tensor,
    x_t_latent_batch: torch.Tensor,
    idx: Optional[int] = None,
) -> torch.Tensor:
    # TODO: use t_list to select beta_prod_t_sqrt
    if idx is None:
        F_theta = (
            x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
        ) / self.alpha_prod_t_sqrt
        denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
    else:
        F_theta = (
            x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
        ) / self.alpha_prod_t_sqrt[idx]
        denoised_batch = (
            self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
        )

    return denoised_batch

def unet_step(
    self,
    x_t_latent: torch.Tensor,
    t_list: Union[torch.Tensor, list[int]],
    added_cond_kwargs,
    idx: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
        x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
        t_list = torch.concat([t_list[0:1], t_list], dim=0)
    elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
        x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
        t_list = torch.concat([t_list, t_list], dim=0)
    else:
        x_t_latent_plus_uc = x_t_latent
    model_pred = self.unet(
        x_t_latent_plus_uc,
        t_list,
        encoder_hidden_states=self.prompt_embeds,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]
    if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
        noise_pred_text = model_pred[1:]
        self.stock_noise = torch.concat(
            [model_pred[0:1], self.stock_noise[1:]], dim=0
        )  # comment this line out to use self CFG output
    elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
        noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
    else:
        noise_pred_text = model_pred
    if self.guidance_scale > 1.0 and (
        self.cfg_type == "self" or self.cfg_type == "initialize"
    ):
        noise_pred_uncond = self.stock_noise * self.delta
    if self.guidance_scale > 1.0 and self.cfg_type != "none":
        model_pred = noise_pred_uncond + self.guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )
    else:
        model_pred = noise_pred_text

    # compute the previous noisy sample x_t -> x_t-1
    if self.use_denoising_batch:
        denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
        if self.cfg_type == "self" or self.cfg_type == "initialize":
            scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
            delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
            alpha_next = torch.concat(
                [
                    self.alpha_prod_t_sqrt[1:],
                    torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                ],
                dim=0,
            )
            delta_x = alpha_next * delta_x
            beta_next = torch.concat(
                [
                    self.beta_prod_t_sqrt[1:],
                    torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                ],
                dim=0,
            )
            delta_x = delta_x / beta_next
            init_noise = torch.concat(
                [self.init_noise[1:], self.init_noise[0:1]], dim=0
            )
            self.stock_noise = init_noise + delta_x

    else:
        # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
        denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

    return denoised_batch, model_pred

# ADD
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    passed_add_embed_dim = (
        self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids

def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
    image_tensors = image_tensors.to(
        device=self.device,
        dtype=self.vae.dtype,
    )
    img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
    img_latent = img_latent * self.vae.config.scaling_factor
    x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
    return x_t_latent

def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
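    # NOTE: decode in float32; the SDXL fp16 VAE overflows and produces NaN,
    # which was the root cause of this issue.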
    self.vae = self.vae.to(torch.float32)
    x_0_pred_out = x_0_pred_out.to(torch.float32)
    output_latent = self.vae.decode(
        x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
    )[0]
    return output_latent

def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
    prev_latent_batch = self.x_t_latent_buffer
    if self.use_denoising_batch:
        t_list = self.sub_timesteps_tensor
        if self.denoising_steps_num > 1:
            x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
            self.stock_noise = torch.cat(
                (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
            )
        added_cond_kwargs = {"text_embeds": self.add_text_embeds.to(self.device), "time_ids": self.add_time_ids.to(self.device)}
        x_t_latent = x_t_latent.to(self.device)
        t_list = t_list.to(self.device)
        x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list, added_cond_kwargs=added_cond_kwargs)
        if self.denoising_steps_num > 1:
            x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
            if self.do_add_noise:
                self.x_t_latent_buffer = (
                    self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                    + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                )
            else:
                self.x_t_latent_buffer = (
                    self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                )
        else:
            x_0_pred_out = x_0_pred_batch
            self.x_t_latent_buffer = None
    else:
        self.init_noise = x_t_latent
        for idx, t in enumerate(self.sub_timesteps_tensor):
            t = t.view(
                1,
            ).repeat(
                self.frame_bff_size,
            )
            added_cond_kwargs = {"text_embeds": self.add_text_embeds.to(self.device), "time_ids": self.add_time_ids.to(self.device)}
        x_0_pred, model_pred = self.unet_step(x_t_latent, t, added_cond_kwargs=added_cond_kwargs, idx=idx)
            
            if idx < len(self.sub_timesteps_tensor) - 1:
                if self.do_add_noise:
                    x_t_latent = self.alpha_prod_t_sqrt[
                        idx + 1
                    ] * x_0_pred + self.beta_prod_t_sqrt[
                        idx + 1
                    ] * torch.randn_like(
                        x_0_pred, device=self.device, dtype=self.dtype
                    )
                else:
                    x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
        x_0_pred_out = x_0_pred
    return x_0_pred_out

@torch.no_grad()
def __call__(
    self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
) -> torch.Tensor:

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    if x is not None:
        x = self.image_processor.preprocess(x, self.height, self.width).to(
            device=self.device, dtype=self.dtype
        )
        if self.similar_image_filter:
            x = self.similar_filter(x)
            if x is None:
                time.sleep(self.inference_time_ema)
                return self.prev_image_result
        x_t_latent = self.encode_image(x)
    else:
        # TODO: check the dimension of x_t_latent
        x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
            device=self.device, dtype=self.dtype
        )

    x_0_pred_out = self.predict_x0_batch(x_t_latent)
    x_output = self.decode_image(x_0_pred_out).detach().clone()

    self.prev_image_result = x_output
    end.record()
    torch.cuda.synchronize()
    inference_time = start.elapsed_time(end) / 1000
    self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
    return x_output

@torch.no_grad()
def txt2img(self, batch_size: int = 1) -> torch.Tensor:
    x_0_pred_out = self.predict_x0_batch(
        torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to(
            device=self.device, dtype=self.dtype
        )
    )
    x_output = self.decode_image(x_0_pred_out).detach().clone()
    return x_output

def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
    x_t_latent = torch.randn(
        (batch_size, 4, self.latent_height, self.latent_width),
        device=self.device,
        dtype=self.dtype,
    )
    model_pred = self.unet(
        x_t_latent,
        self.sub_timesteps_tensor,
        encoder_hidden_states=self.prompt_embeds,
        return_dict=False,
    )[0]
    x_0_pred_out = (
        x_t_latent - self.beta_prod_t_sqrt * model_pred
    ) / self.alpha_prod_t_sqrt
    return self.decode_image(x_0_pred_out)

```
Replace the original StreamDiffusion class with this and try it.
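A minimal usage sketch for the class above; the model id, LoRA id, input file, and timestep indices are illustrative assumptions, not tested settings:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline

# Illustrative model and settings; adjust to your own setup.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

stream = StreamDiffusion(pipe, t_index_list=[32, 45], width=1024, height=1024)
stream.load_lcm_lora("latent-consistency/lcm-lora-sdxl")  # SDXL LCM-LoRA
stream.fuse_lora()
stream.prepare(prompt="a photo of a cat", num_inference_steps=50)

init_image = Image.open("input.png").convert("RGB").resize((1024, 1024))
result_tensor = stream(init_image)  # img2img; decode_image casts the VAE to fp32
```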

ApolloRay avatar Feb 20 '24 08:02 ApolloRay

# ADD marks the code I added.

ApolloRay avatar Feb 20 '24 08:02 ApolloRay

Thank you for looking into this @ApolloRay - I did have to make one more tweak to get sdxl-turbo working (passing down the added_cond_kwargs). As a VAE, I swapped out madebyollin/taesd for madebyollin/taesdxl.
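A minimal sketch of that VAE swap, assuming the tiny VAE is loaded through diffusers' AutoencoderTiny as in the upstream examples:

```python
import torch
from diffusers import AutoencoderTiny

# Use the SDXL variant of the tiny autoencoder instead of the SD1.5 one.
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesdxl", torch_dtype=torch.float16
).to(pipe.device)
```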

wouterverweirder avatar Feb 21 '24 16:02 wouterverweirder

added_cond_kwargs

Hey there, I find myself at a bit of a loss as to how to pass down the correct arguments when calling the pipeline. Could you please elaborate?

menguzat avatar Mar 05 '24 18:03 menguzat

Figured it out.

You have to replace the StableDiffusionPipeline references in wrapper.py with StableDiffusionXLPipeline.

Furthermore, if you want to use fp16, you'll have to replace all the float32s (with float16) in the pipeline, the wrapper, your image transmission code, and possibly image_utils.py.

If you want to use fp32 instead, I think you have to change all of these to fp32 (there are some that are fp16 in the repo).

Other than that, @ApolloRay kudos and thanks a lot!
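A rough sketch of the wrapper-side change described above (assuming the wrapper builds the base pipeline roughly the way utils/wrapper.py does upstream; the model id is illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline  # was StableDiffusionPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,  # fp16 weights; cast the VAE to fp32 at decode time
).to("cuda")
```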

menguzat avatar Mar 06 '24 01:03 menguzat

@ApolloRay Can you submit a PR? I'd appreciate it if you did 🥺🥺🥺 I can't wait to use the ultra-fast SDXL model! But my attempt with your code failed. I don't know what went wrong. THX!

Mars160 avatar Mar 12 '24 16:03 Mars160

Hey, if anyone is really biting their nails over how to make this work: I currently have it powering the front page of https://pollinations.ai. I'm using the DreamShaper Lightning XL model together with StreamDiffusion.

I made some unrelated tweaks in the code connected to Pollinations, but I could make a pull request out of it if there is interest.

https://github.com/pollinations/pollinations/tree/master/image_gen/StreamDiffusion

Sorry for not making a proper fork. Can do later.

voodoohop avatar Mar 20 '24 16:03 voodoohop

I have added SDXL support to the pipeline and wrapper based on @ApolloRay's code (thanks for that!). I have also added single-image and real-time image generation examples for sdxl-turbo. Check out my sdxl branch.

https://github.com/hkn-g/StreamDiffusion/tree/sdxl

sdxl-turbo img2img looks fine, but there are some other issues, such as TensorRT not working and txt2img giving an error with an SDXL model. However, img2img mode does work even without an input image.

hkn-g avatar Mar 20 '24 17:03 hkn-g