CFGpp
Implementing for flow models?
I tried to implement this for flow models as described in the appendix, but the result is complete collapse (exploding images). Did I make a mistake, or is this technique fundamentally incompatible with flow models (which have no renoising step)? Also, the paper doesn't define $v_\lambda$.
import torch

def euler_cfgpp_update(
    x_t: torch.Tensor,
    t: float,
    dt: float,
    v_u: torch.Tensor,
    v_c: torch.Tensor,
    lambda_val: float,
) -> torch.Tensor:
    # Convention in my setup: t = 0 is pure noise, t = 1 is the clean image.
    # Unconditional velocity at (x_t, t):
    # v_u = model_uncond(x_t, t)
    # Conditional velocity at (x_t, t):
    # v_c = model_cond(x_t, t)
    # Unconditional "Tweedie" estimate: x0_null = x_t + (1 - t) * v_u
    x_null = x_t + (1 - t) * v_u
    # Conditional "Tweedie" estimate: x0_cond = x_t + (1 - t) * v_c
    x_cond = x_t + (1 - t) * v_c
    # Normal CFG prediction, for comparison:
    # x_cfg = x_t + (1 - t) * (v_u + 2.3 * (v_c - v_u))
    # CFG++ "Tweedie" estimate (interpolation):
    # x0_cfgpp = (1 - lambda) * x0_null + lambda * x0_cond
    x_cfgpp = x_null + lambda_val * (x_cond - x_null)
    # Next time t_next = t + dt
    t_next = t + dt
    # Euler step for CFG++: renoise the CFG++ estimate along the unconditional
    # noise direction (x_t - x0_null) / (1 - t), rescaled to the next level (1 - t_next).
    # (Make sure t != 1 to avoid divide-by-zero!)
    # eps = 1e-12
    x_next = x_cfgpp + (x_t - x_null) * ((1 - t_next) / (1 - t))
    # Vanilla CFG update, for comparison:
    # x_next = x_cfg + (x_t - x_cfg) * ((1 - t_next) / (1 - (t + eps)))
    return x_next
@geonyeong-park @CFGpp-diffusion @jeongsol-kim
@zaptrem I think the implementation differs from CFG++ for flow models as described in Appendix B of https://openreview.net/pdf?id=E77uvbOTtp . BTW, I'm wondering about the results of CFG++ with flow matching (e.g., FLUX), but I can't find any results for it.
@zaptrem @LeeDoYup Thank you for your interest. You should check two points:

- If $t$ is sampled from the scheduler of diffusers, its range is between 0 and 1000, whereas $\sigma_t$ is defined within the range $[0, 1]$. Therefore, please verify the range of $t$, and if $t \in [0, 1000]$, use $t/1000$ instead of $t$ when computing the Tweedie estimate.
- When using SD 3.0, the clean estimate should be $x_t - t \cdot v_u$. Please visualize your x_null and x_cond and confirm whether they represent clean images. For reference, the noise estimate is given by $x_t + (1 - t) \cdot v_u$.

While implementing CFG++ for Stable Diffusion 3.0-medium, we observed that a relatively small $\lambda$ (~0.1) works well. If you have any further questions, please feel free to open an issue.
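For reference, a minimal sketch (not the authors' code) of the two checks above, assuming a diffusers-style timestep in [0, 1000] and the SD3 convention where $\sigma = 1$ is pure noise; the helper name is made up:

import torch

def tweedie_estimates(x_t: torch.Tensor, v: torch.Tensor, t: float):
    # Hypothetical helper: t is assumed to come from the diffusers scheduler in [0, 1000].
    sigma = t / 1000.0                 # normalize the scheduler timestep to [0, 1]
    x0_hat = x_t - sigma * v           # clean (Tweedie) estimate: x_t - sigma * v
    eps_hat = x_t + (1.0 - sigma) * v  # noise estimate: x_t + (1 - sigma) * v
    return x0_hat, eps_hat             # decode/visualize x0_hat to confirm it looks clean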
Best,
Thanks! I'll check this further in the morning, but before that I can confirm that in my setup, t is defined between 0 (pure noise) and 1 (no noise). Could you share your SD3 implementation for reference?
Got it. I believe that setup is slightly different. In my case, $t = 1$ represents full noise, while $t = 0$ corresponds to a clean image. I attached my quick implementation below. Aside from minor bugs, it should work as expected.
When sampling with the prompt "A small cactus with a happy face in the Sahara Desert" at an image size of 768×768, CFG++ yields an improvement. However, for a more accurate comparison, we should compute FID/LPIPS and use the corresponding scales, as done in the main manuscript.
| CFG (scale 7.5) | CFG++ (scale 0.15) |
|---|---|
| *(image)* | *(image)* |
Best,
# Imports assumed for this snippet (not shown in the original comment):
from typing import List, Optional, Tuple
import torch
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline

class StableDiffusion3Base():
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16):
self.device = device
self.dtype = dtype
pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype)
self.scheduler = pipe.scheduler
self.tokenizer_1 = pipe.tokenizer
self.tokenizer_2 = pipe.tokenizer_2
self.tokenizer_3 = pipe.tokenizer_3
self.text_enc_1 = pipe.text_encoder
self.text_enc_2 = pipe.text_encoder_2
self.text_enc_3 = pipe.text_encoder_3
self.vae=pipe.vae
self.transformer = pipe.transformer
self.transformer.eval()
self.transformer.requires_grad_(False)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8
)
del pipe
def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]:
'''
We assume that
1. number of tokens < max_length
2. one prompt for one image
'''
# CLIP encode (used for modulation of adaLN-zero)
# now, we have two CLIPs
text_clip1_ids = self.tokenizer_1(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.text_enc_1.device), output_hidden_states=True)
pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip2_ids = self.tokenizer_2(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.text_enc_2.device), output_hidden_states=True)
pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.text_enc_2.device)
text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_2.device)
# T5 encode (used for text condition)
text_t5_ids = self.tokenizer_3(prompt,
padding="max_length",
max_length=77,
truncation=True,
add_special_tokens=True,
return_tensors='pt').input_ids
text_t5_emb = self.text_enc_3(text_t5_ids.to(self.text_enc_3.device))[0]
text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.text_enc_3.device)
# Merge
clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1)
clip_prompt_emb = torch.nn.functional.pad(
clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1])
)
prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2)
pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1)
return prompt_emb, pooled_prompt_emb
def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs):
H, W = img_size
lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor
lC = self.transformer.config.in_channels
latent_shape = (batch_size, lC, lH, lW)
z = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
return z
def encode(self, image: torch.Tensor) -> torch.Tensor:
z = self.vae.encode(image).latent_dist.sample()
z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor
return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor
return self.vae.decode(z, return_dict=False)[0]
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.transformer(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
class SD3Euler(StableDiffusion3Base):
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'):
super().__init__(model_key=model_key, device=device)
def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
            prompt_emb = prompt_emb.to(self.transformer.device)
            pooled_emb = pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
            null_prompt_emb = null_prompt_emb.to(self.transformer.device)
            null_pooled_emb = null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
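        # Note: diffusers timesteps lie in [0, num_train_timesteps]; dividing by
        # num_train_timesteps maps them to sigma in (0, 1], with sigma = 1 at pure noise.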
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
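            # CFG++-style update (sigma = 1 is pure noise here):
            #   z0_null : unconditional Tweedie (clean) estimate
            #   z0_cond : guided clean estimate with guidance scale cfg_scale
            # The guided clean estimate is then renoised along the *unconditional*
            # noise direction (z - z0_null) / sigma, rescaled to sigma_next.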
z0_null = z - sigma * pred_null_v
z0_cond = z - sigma * (pred_null_v + cfg_scale*(pred_v - pred_null_v))
z = z0_cond + (z - z0_null)/sigma*sigma_next
# decode
with torch.no_grad():
img = self.decode(z)
return img
@jeongsol-kim Thanks for quickly sharing the results & PoC code! It's amazing!
Can you share what the cactus looks like at extreme scales (e.g., 0.95, 0.01)? I'm curious whether this actually completely solves the oversaturation/overshooting manifold problem or simply reduces it.
Have you tried it on the SD3.5 Large model? I used cfg_scale=0.15 and encountered distortion issues. My prompt: "Captain America stands by the lake."
Hi everyone @zaptrem @LeeDoYup @SCUTykLin,
Apologies for the late response. We found that the correct implementation is as shown below. With this, the proper scales match those reported in our paper (i.e., 0.6, 0.8).
What is different?
- An Euler update can be decomposed into a clean-image estimate and a noise estimate.
- For the clean-image estimate, use the conditioned (guided) velocity.
- For the noise estimate, use the unconditional velocity (see the equations written out below).
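Written out, the three lines in the ##### Changed ##### block below read as (a direct transcription of the code, with $\sigma = 1$ at pure noise and $\lambda$ = cfg_scale):

$$
\begin{aligned}
z_0^t &= z_t - \sigma_i \big( v_\emptyset + \lambda (v_c - v_\emptyset) \big) && \text{(clean-image estimate, guided)} \\
z_1^t &= z_t + (1 - \sigma_i)\, v_\emptyset && \text{(noise estimate, unconditional)} \\
z_{t_{i+1}} &= (1 - \sigma_{i+1})\, z_0^t + \sigma_{i+1}\, z_1^t
\end{aligned}
$$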
# Imports assumed for this snippet (not shown in the original comment):
from typing import List, Optional, Tuple
import torch
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline

class StableDiffusion3Base():
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16):
self.device = device
self.dtype = dtype
pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype)
self.scheduler = pipe.scheduler
self.tokenizer_1 = pipe.tokenizer
self.tokenizer_2 = pipe.tokenizer_2
self.tokenizer_3 = pipe.tokenizer_3
self.text_enc_1 = pipe.text_encoder
self.text_enc_2 = pipe.text_encoder_2
self.text_enc_3 = pipe.text_encoder_3
self.vae=pipe.vae
self.transformer = pipe.transformer
self.transformer.eval()
self.transformer.requires_grad_(False)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8
)
del pipe
def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]:
'''
We assume that
1. number of tokens < max_length
2. one prompt for one image
'''
# CLIP encode (used for modulation of adaLN-zero)
# now, we have two CLIPs
text_clip1_ids = self.tokenizer_1(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.text_enc_1.device), output_hidden_states=True)
pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_1.device)
text_clip2_ids = self.tokenizer_2(prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors='pt').input_ids
text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.text_enc_2.device), output_hidden_states=True)
pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.text_enc_2.device)
text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_2.device)
# T5 encode (used for text condition)
text_t5_ids = self.tokenizer_3(prompt,
padding="max_length",
max_length=77,
truncation=True,
add_special_tokens=True,
return_tensors='pt').input_ids
text_t5_emb = self.text_enc_3(text_t5_ids.to(self.text_enc_3.device))[0]
text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.text_enc_3.device)
# Merge
clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1)
clip_prompt_emb = torch.nn.functional.pad(
clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1])
)
prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2)
pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1)
return prompt_emb, pooled_prompt_emb
def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs):
H, W = img_size
lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor
lC = self.transformer.config.in_channels
latent_shape = (batch_size, lC, lH, lW)
z = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
return z
def encode(self, image: torch.Tensor) -> torch.Tensor:
z = self.vae.encode(image).latent_dist.sample()
z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor
return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor
return self.vae.decode(z, return_dict=False)[0]
def predict_vector(self, z, t, prompt_emb, pooled_emb):
v = self.transformer(hidden_states=z,
timestep=t,
pooled_projections=pooled_emb,
encoder_hidden_states=prompt_emb,
return_dict=False)[0]
return v
class SD3Euler(StableDiffusion3Base):
def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'):
super().__init__(model_key=model_key, device=device)
def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None,
cfg_scale: float=1.0, batch_size: int = 1,
latent:Optional[List[torch.Tensor]]=None,
prompt_emb:Optional[List[torch.Tensor]]=None,
null_emb:Optional[List[torch.Tensor]]=None):
imgH, imgW = img_shape if img_shape is not None else (1024, 1024)
# encode text prompts
with torch.no_grad():
if prompt_emb is None:
prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
else:
prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1]
            prompt_emb = prompt_emb.to(self.transformer.device)
            pooled_emb = pooled_emb.to(self.transformer.device)
if null_emb is None:
null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)
else:
null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1]
            null_prompt_emb = null_prompt_emb.to(self.transformer.device)
            null_pooled_emb = null_pooled_emb.to(self.transformer.device)
# initialize latent
if latent is None:
z = self.initialize_latent((imgH, imgW), batch_size)
else:
z = latent
# timesteps (default option. You can make your custom here.)
self.scheduler.set_timesteps(NFE, device=self.device)
timesteps = self.scheduler.timesteps
sigmas = timesteps / self.scheduler.config.num_train_timesteps
# Solve ODE
pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler')
for i, t in enumerate(pbar):
timestep = t.expand(z.shape[0]).to(self.device)
pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
if cfg_scale != 1.0:
pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
else:
pred_null_v = 0.0
sigma = sigmas[i]
sigma_next = sigmas[i+1] if i+1 < NFE else 0.0
##### Changed #######
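            # z0t: clean-image estimate using the guided (conditioned) velocity
            # z1t: noise estimate using the unconditional velocity
            # z  : re-interpolate between them at the next noise level sigma_next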
z0t = z - sigma * (pred_null_v + cfg_scale*(pred_v - pred_null_v))
z1t = z + (1-sigma) * pred_null_v
z = (1-sigma_next) * z0t + sigma_next * z1t
####################
# decode
with torch.no_grad():
img = self.decode(z)
return img
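A hypothetical call to the sampler above, using the prompt from earlier in this thread and one of the scales (0.6) mentioned above; NFE=28 is an arbitrary choice for illustration:

sd3 = SD3Euler()
img = sd3.sample(
    prompts=["A small cactus with a happy face in the Sahara Desert"],
    NFE=28,                 # number of Euler steps (arbitrary here)
    img_shape=(768, 768),
    cfg_scale=0.6,          # CFG++ scale; 0.6-0.8 reported to work well above
)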
> Have you tried it on the SD3.5 Large model? I used cfg_scale=0.15 and encountered distortion issues. My prompt: "Captain America stands by the lake."

I found I have not tried it on SD3.5 Large yet.
Great! Thanks for your reply.
But I have some questions about your code—it doesn't seem to align with the algorithm described in your paper.
I think that there are three steps:
(1) $\epsilon_c^{\lambda} = \epsilon_{\emptyset}(x_t) + \lambda\,[\epsilon_c(x_t) - \epsilon_{\emptyset}(x_t)]$
(2) $x_c^{\lambda} = (x_t - \sqrt{1 - \alpha_t}\,\epsilon_c^{\lambda})/\sqrt{\alpha_t}$. Flow matching: $x_t = x_0 + \sigma_t (\epsilon - x_0) \;\rightarrow\; \hat{x}_0 = x_t - \sigma_t v_t^c$
(3) $x_{t-1} = \sqrt{\alpha_{t-1}}\,x_c^{\lambda}(x_t) + \sqrt{1 - \alpha_{t-1}}\,\epsilon_{\emptyset}(x_t)$. Flow matching: $x_{t+1} = \hat{x}_0 + \sigma_{t+1} v_t^u$
However, in the code you provided:

z0t = z - sigma * (pred_null_v + cfg_scale*(pred_v - pred_null_v))

is the $x_c^{\lambda} = (x_t - \sqrt{1 - \alpha_t}\,\epsilon_c^{\lambda})/\sqrt{\alpha_t}$ step. What is the meaning of the next two lines?
z1t = z + (1-sigma) * pred_null_v
z = (1-sigma_next) * z0t + sigma_next * z1t
Thanks.
---

I found the reason:
z1t = z + (1-sigma) * pred_null_v
$$
\begin{aligned}
z_1^t &= z_t + (1 - \sigma_t)\, v_t^u \\
&= (1 - \sigma_t)\, z_0 + \sigma_t \epsilon + (1 - \sigma_t)(\epsilon - z_0) \\
&= \epsilon \quad \text{(predicted unconditional noise)} \\
z_{t+1} &= (1 - \sigma_{t+1})\, z_0^t + \sigma_{t+1}\, z_1^t
\end{aligned}
$$
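A quick numeric sanity check of this identity (a sketch with made-up tensors, assuming the flow-matching interpolation $z_t = (1 - \sigma_t) z_0 + \sigma_t \epsilon$ and $v = \epsilon - z_0$):

import torch

z0, eps = torch.randn(4, 4), torch.randn(4, 4)
sigma = 0.37
z_t = (1 - sigma) * z0 + sigma * eps
v = eps - z0
# z_t + (1 - sigma) * v recovers eps, as in the derivation above.
assert torch.allclose(z_t + (1 - sigma) * v, eps, atol=1e-5)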
But why do we need to isolate the noise and treat it separately? Can't we treat the noise $\epsilon$ as part of the velocity $v$ calculation? BTW, at first I tried using the velocity in place of the noise estimate, which caused an oversaturation issue, but switching to the noise estimate worked normally.
Thanks!
Treating the noise estimate separately gave me a lot of inspiration! Thank you very much, author. Wishing you great success!
Best,