StreamDiffusion
After VAE decode, the result is NaN in SDXL mode.
I tried StreamDiffusion with an SDXL model, but after the VAE decode the result image is NaN. How can I solve this? I already added added_cond_kwargs in unet_step. Or, when will StreamDiffusion support SDXL models?
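(For context, not part of the original report: an SDXL UNet expects two extra conditioning tensors via `added_cond_kwargs` that the SD1.5 code path never builds. The shapes below are for batch size 1 and the tensors are dummies for illustration.)

```python
import torch

# Dummy tensors with the shapes an SDXL UNet expects; in the real pipeline they
# come from the pooled output of text_encoder_2 and from _get_add_time_ids.
pooled_prompt_embeds = torch.zeros(1, 1280, dtype=torch.float16)
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float16)

added_cond_kwargs = {
    "text_embeds": pooled_prompt_embeds,
    "time_ids": add_time_ids,
}
# Passed through the UNet call as:
# unet(latents, t, encoder_hidden_states=prompt_embeds,
#      added_cond_kwargs=added_cond_kwargs, return_dict=False)
```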
I have solved this problem, and StreamDiffusion can support SDXL models. The key was to change the VAE dtype to fp32; otherwise the result overflows. OHOHOHOHOH!!!!!
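A minimal sketch of that workaround (names are illustrative), matching what the code posted further down does in `decode_image`; another common route is `madebyollin/sdxl-vae-fp16-fix`, an SDXL VAE finetuned to stay within fp16 range:

```python
import torch

def decode_in_fp32(vae, latents):
    # Run only the VAE decode in fp32: in fp16 the stock SDXL VAE overflows
    # and the decoded image comes back as NaN.
    vae = vae.to(torch.float32)
    latents = latents.to(torch.float32)
    return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
```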
Thank you very much. Please feel free to submit a PR.
@ApolloRay Nice, can you share your method with us, and show us the txt2img speed? Thanks.
I will refine my code and release it soon.
Any progress on this? I'm trying to load SDXL by tampering with the code, but I have never worked with diffusers before. We'd probably need to replace the StableDiffusionPipeline calls with StableDiffusionXLPipeline calls, and probably a bunch of other things?
Or don't we?
Any pointers would be appreciated at this point.
@ApolloRay Hi, did you succeed in making SDXL work with StreamDiffusion? How's the performance?
```python
import time
from typing import List, Optional, Union, Any, Dict, Tuple, Literal

import numpy as np
import PIL.Image
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)

from streamdiffusion.image_filter import SimilarImageFilter


class StreamDiffusion:
    def __init__(
        self,
        pipe: StableDiffusionXLPipeline,
        t_index_list: List[int],
        torch_dtype: torch.dtype = torch.float16,
        width: int = 1024,
        height: int = 1024,
        do_add_noise: bool = True,
        use_denoising_batch: bool = True,
        frame_buffer_size: int = 1,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
    ) -> None:
        self.device = pipe.device
        self.dtype = torch_dtype
        self.generator = None
        self.height = height
        self.width = width

        self.latent_height = int(height // pipe.vae_scale_factor)
        self.latent_width = int(width // pipe.vae_scale_factor)

        self.frame_bff_size = frame_buffer_size
        self.denoising_steps_num = len(t_index_list)

        self.cfg_type = cfg_type

        if use_denoising_batch:
            self.batch_size = self.denoising_steps_num * frame_buffer_size
            if self.cfg_type == "initialize":
                self.trt_unet_batch_size = (
                    self.denoising_steps_num + 1
                ) * self.frame_bff_size
            elif self.cfg_type == "full":
                self.trt_unet_batch_size = (
                    2 * self.denoising_steps_num * self.frame_bff_size
                )
            else:
                self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
        else:
            self.trt_unet_batch_size = self.frame_bff_size
            self.batch_size = frame_buffer_size

        self.t_list = t_index_list

        self.do_add_noise = do_add_noise
        self.use_denoising_batch = use_denoising_batch

        self.similar_image_filter = False
        self.similar_filter = SimilarImageFilter()
        self.prev_image_result = None

        self.pipe = pipe
        self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

        self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.vae = pipe.vae

        self.inference_time_ema = 0
    def load_lcm_lora(
        self,
        pretrained_model_name_or_path_or_dict: Union[
            str, Dict[str, torch.Tensor]
        ] = "latent-consistency/lcm-lora-sdv1-5",
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def load_lora(
        self,
        pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
    ) -> None:
        self.pipe.fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
        )

    def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
        self.similar_image_filter = True
        self.similar_filter.set_threshold(threshold)
        self.similar_filter.set_max_skip_frame(max_skip_frame)

    def disable_similar_image_filter(self) -> None:
        self.similar_image_filter = False
    @torch.no_grad()
    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
        generator: Optional[torch.Generator] = torch.Generator(),
        seed: int = 2,
    ) -> None:
        self.generator = generator
        self.generator.manual_seed(seed)
        # initialize x_t_latent (it can be any random tensor)
        if self.denoising_steps_num > 1:
            self.x_t_latent_buffer = torch.zeros(
                (
                    (self.denoising_steps_num - 1) * self.frame_bff_size,
                    4,
                    self.latent_height,
                    self.latent_width,
                ),
                dtype=self.dtype,
                device=self.device,
            )
        else:
            self.x_t_latent_buffer = None

        if self.cfg_type == "none":
            self.guidance_scale = 1.0
        else:
            self.guidance_scale = guidance_scale
        self.delta = delta

        do_classifier_free_guidance = False
        if self.guidance_scale > 1.0:
            do_classifier_free_guidance = True

        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

        # ADD
        self.add_text_embeds = encoder_output[2]
        original_size = (self.height, self.width)
        crops_coords_top_left = (0, 0)
        target_size = (self.height, self.width)
        text_encoder_projection_dim = int(self.add_text_embeds.shape[-1])
        self.add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=encoder_output[0].dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        if self.use_denoising_batch and self.cfg_type == "full":
            uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
        elif self.cfg_type == "initialize":
            uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

        if self.guidance_scale > 1.0 and (
            self.cfg_type == "initialize" or self.cfg_type == "full"
        ):
            self.prompt_embeds = torch.cat(
                [uncond_prompt_embeds, self.prompt_embeds], dim=0
            )

        self.scheduler.set_timesteps(num_inference_steps, self.device)
        self.timesteps = self.scheduler.timesteps.to(self.device)

        # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
        self.sub_timesteps = []
        for t in self.t_list:
            self.sub_timesteps.append(self.timesteps[t])

        sub_timesteps_tensor = torch.tensor(
            self.sub_timesteps, dtype=torch.long, device=self.device
        )
        self.sub_timesteps_tensor = torch.repeat_interleave(
            sub_timesteps_tensor,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

        self.init_noise = torch.randn(
            (self.batch_size, 4, self.latent_height, self.latent_width),
            generator=generator,
        ).to(device=self.device, dtype=self.dtype)

        self.stock_noise = torch.zeros_like(self.init_noise)

        c_skip_list = []
        c_out_list = []
        for timestep in self.sub_timesteps:
            c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
                timestep
            )
            c_skip_list.append(c_skip)
            c_out_list.append(c_out)

        self.c_skip = (
            torch.stack(c_skip_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.c_out = (
            torch.stack(c_out_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )

        alpha_prod_t_sqrt_list = []
        beta_prod_t_sqrt_list = []
        for timestep in self.sub_timesteps:
            alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
            beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
            alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
            beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
        alpha_prod_t_sqrt = (
            torch.stack(alpha_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        beta_prod_t_sqrt = (
            torch.stack(beta_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.alpha_prod_t_sqrt = torch.repeat_interleave(
            alpha_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )
        self.beta_prod_t_sqrt = torch.repeat_interleave(
            beta_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )
    @torch.no_grad()
    def update_prompt(self, prompt: str) -> None:
        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        t_index: int,
    ) -> torch.Tensor:
        noisy_samples = (
            self.alpha_prod_t_sqrt[t_index] * original_samples
            + self.beta_prod_t_sqrt[t_index] * noise
        )
        return noisy_samples

    def scheduler_step_batch(
        self,
        model_pred_batch: torch.Tensor,
        x_t_latent_batch: torch.Tensor,
        idx: Optional[int] = None,
    ) -> torch.Tensor:
        # TODO: use t_list to select beta_prod_t_sqrt
        if idx is None:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
            ) / self.alpha_prod_t_sqrt
            denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
        else:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
            ) / self.alpha_prod_t_sqrt[idx]
            denoised_batch = (
                self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
            )

        return denoised_batch
    def unet_step(
        self,
        x_t_latent: torch.Tensor,
        t_list: Union[torch.Tensor, list[int]],
        added_cond_kwargs,
        idx: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
            t_list = torch.concat([t_list[0:1], t_list], dim=0)
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
            t_list = torch.concat([t_list, t_list], dim=0)
        else:
            x_t_latent_plus_uc = x_t_latent

        model_pred = self.unet(
            x_t_latent_plus_uc,
            t_list,
            encoder_hidden_states=self.prompt_embeds,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            noise_pred_text = model_pred[1:]
            self.stock_noise = torch.concat(
                [model_pred[0:1], self.stock_noise[1:]], dim=0
            )  # commenting this out here falls back to self-output CFG
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and (
            self.cfg_type == "self" or self.cfg_type == "initialize"
        ):
            noise_pred_uncond = self.stock_noise * self.delta
        if self.guidance_scale > 1.0 and self.cfg_type != "none":
            model_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            model_pred = noise_pred_text

        # compute the previous noisy sample x_t -> x_t-1
        if self.use_denoising_batch:
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
            if self.cfg_type == "self" or self.cfg_type == "initialize":
                scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
                delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
                alpha_next = torch.concat(
                    [
                        self.alpha_prod_t_sqrt[1:],
                        torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = alpha_next * delta_x
                beta_next = torch.concat(
                    [
                        self.beta_prod_t_sqrt[1:],
                        torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = delta_x / beta_next
                init_noise = torch.concat(
                    [self.init_noise[1:], self.init_noise[0:1]], dim=0
                )
                self.stock_noise = init_noise + delta_x
        else:
            # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        return denoised_batch, model_pred
    # ADD
    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
        image_tensors = image_tensors.to(
            device=self.device,
            dtype=self.vae.dtype,
        )
        img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
        img_latent = img_latent * self.vae.config.scaling_factor
        x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
        # cast back to the UNet dtype in case the VAE has been promoted to fp32 by decode_image
        return x_t_latent.to(dtype=self.dtype)

    def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
        # ADD: run the SDXL VAE decode in fp32; in fp16 it overflows and the output becomes NaN
        self.vae = self.vae.to(torch.float32)
        x_0_pred_out = x_0_pred_out.to(torch.float32)
        output_latent = self.vae.decode(
            x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
        )[0]
        return output_latent
    def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
        prev_latent_batch = self.x_t_latent_buffer

        if self.use_denoising_batch:
            t_list = self.sub_timesteps_tensor
            if self.denoising_steps_num > 1:
                x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
                self.stock_noise = torch.cat(
                    (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
                )
            added_cond_kwargs = {"text_embeds": self.add_text_embeds.to(self.device), "time_ids": self.add_time_ids.to(self.device)}
            x_t_latent = x_t_latent.to(self.device)
            t_list = t_list.to(self.device)
            x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list, added_cond_kwargs=added_cond_kwargs)

            if self.denoising_steps_num > 1:
                x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
                if self.do_add_noise:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                        + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                    )
                else:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                    )
            else:
                x_0_pred_out = x_0_pred_batch
                self.x_t_latent_buffer = None
        else:
            self.init_noise = x_t_latent
            for idx, t in enumerate(self.sub_timesteps_tensor):
                t = t.view(
                    1,
                ).repeat(
                    self.frame_bff_size,
                )
                added_cond_kwargs = {"text_embeds": self.add_text_embeds.to(self.device), "time_ids": self.add_time_ids.to(self.device)}
                x_0_pred, model_pred = self.unet_step(x_t_latent, t, added_cond_kwargs=added_cond_kwargs, idx=idx)
                if idx < len(self.sub_timesteps_tensor) - 1:
                    if self.do_add_noise:
                        x_t_latent = self.alpha_prod_t_sqrt[
                            idx + 1
                        ] * x_0_pred + self.beta_prod_t_sqrt[
                            idx + 1
                        ] * torch.randn_like(
                            x_0_pred, device=self.device, dtype=self.dtype
                        )
                    else:
                        x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
            x_0_pred_out = x_0_pred

        return x_0_pred_out
    @torch.no_grad()
    def __call__(
        self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
    ) -> torch.Tensor:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if x is not None:
            x = self.image_processor.preprocess(x, self.height, self.width).to(
                device=self.device, dtype=self.dtype
            )
            if self.similar_image_filter:
                x = self.similar_filter(x)
                if x is None:
                    time.sleep(self.inference_time_ema)
                    return self.prev_image_result
            x_t_latent = self.encode_image(x)
        else:
            # TODO: check the dimension of x_t_latent
            x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        x_0_pred_out = self.predict_x0_batch(x_t_latent)
        x_output = self.decode_image(x_0_pred_out).detach().clone()

        self.prev_image_result = x_output
        end.record()
        torch.cuda.synchronize()
        inference_time = start.elapsed_time(end) / 1000
        self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
        return x_output

    @torch.no_grad()
    def txt2img(self, batch_size: int = 1) -> torch.Tensor:
        x_0_pred_out = self.predict_x0_batch(
            torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        )
        x_output = self.decode_image(x_0_pred_out).detach().clone()
        return x_output
    def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
        x_t_latent = torch.randn(
            (batch_size, 4, self.latent_height, self.latent_width),
            device=self.device,
            dtype=self.dtype,
        )
        # ADD: an SDXL UNet also needs the extra conditioning here
        added_cond_kwargs = {
            "text_embeds": self.add_text_embeds.to(self.device),
            "time_ids": self.add_time_ids.to(self.device),
        }
        model_pred = self.unet(
            x_t_latent,
            self.sub_timesteps_tensor,
            encoder_hidden_states=self.prompt_embeds,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        x_0_pred_out = (
            x_t_latent - self.beta_prod_t_sqrt * model_pred
        ) / self.alpha_prod_t_sqrt
        return self.decode_image(x_0_pred_out)
```

Replace the original StreamDiffusion pipeline code with this and try it.
Lines marked `# ADD` indicate the code added for SDXL support.
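For anyone who wants to try the class above, here is a rough, untested usage sketch that drives the pipeline class directly rather than going through wrapper.py; the model id, resolution, and `t_index_list` values are illustrative assumptions, not part of @ApolloRay's code:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative model id and settings; sdxl-turbo is trained at 512x512.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")

stream = StreamDiffusion(
    pipe,
    t_index_list=[0],          # single denoising step
    torch_dtype=torch.float16,
    width=512,
    height=512,
    cfg_type="none",
)
stream.prepare(prompt="a photo of a cat", num_inference_steps=50)

image_tensor = stream.txt2img()  # decoded image tensor from the fp32 VAE
pil_images = stream.image_processor.postprocess(image_tensor, output_type="pil")
pil_images[0].save("sdxl_turbo_sample.png")
```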
Thank you for looking into this @ApolloRay - I did have to make one more tweak to get sdxl-turbo working (passing down the added_cond_kwargs). As the VAE, I swapped madebyollin/taesd out for madebyollin/taesdxl.
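For reference, the tiny-VAE swap mentioned above is roughly this (assuming `stream` is the StreamDiffusion instance, following the same pattern wrapper.py uses for the SD1.5 taesd):

```python
import torch
from diffusers import AutoencoderTiny

# taesdxl is the SDXL-compatible version of the tiny autoencoder used for fast previews.
stream.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesdxl", torch_dtype=torch.float16
).to(device=stream.device)
```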
> added_cond_kwargs
Hey there, I'm at a bit of a loss as to how to pass down the correct arguments when calling the pipeline. Could you please elaborate?
Figured it out.
You have to replace the StableDiffusionPipeline calls in wrapper.py with StableDiffusionXLPipeline.
Furthermore, if you want to use fp16, you'll have to replace the float32s in the pipeline, the wrapper, your image-transmission code, and possibly image_utils.py.
Conversely, if you want to run everything in fp32, I think you have to change all of those to fp32 (some parts of the repo are fp16).
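A loose sketch of what that wrapper-side swap looks like; the model id and dtype choice here are examples, not the exact edit:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Pick one dtype and use it consistently across wrapper.py, pipeline.py,
# and your image transmission code.
dtype = torch.float16
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype
).to("cuda")
```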
other than that @ApolloRay kudos and thanks a lot!
@ApolloRay Can you submit a PR? I'd really appreciate it 🥺🥺🥺 I can't wait to use the ultra-fast SDXL model! But my attempt with your code failed, and I don't know what went wrong. THX!
Hey, if anyone is really biting their nails over how to make this work: I currently have it powering the front page of https://pollinations.ai. I'm using the Dreamshaper Lightning XL model together with StreamDiffusion.
I made some unrelated tweaks in the code connected to pollinations, but I could turn it into a pull request if there is interest.
https://github.com/pollinations/pollinations/tree/master/image_gen/StreamDiffusion
Sorry for not making a proper fork. Can do later.
I have added SDXL support to the pipeline and wrapper based on @ApolloRay's code (thanks for that!). I have also added single-image and real-time image-generation examples for sdxl-turbo. Check out my sdxl branch.
https://github.com/hkn-g/StreamDiffusion/tree/sdxl
sdxl-turbo img2img looks fine, but there are some other issues, such as TensorRT not working and txt2img giving an error with an SDXL model. However, img2img mode works even without an input image.