Image Editing via Inversion
Hi, please correct me if I'm wrong. I tried using the `inverse` function in DPM-Solver to invert the source latent into a noisy latent, and then the `sample` function to obtain the edited image. However, the noisy latent returned by `inverse` is all NaN values. I've included the code below; please take a look.
```python
import argparse
import torch
import sys
import os
import hashlib
import json

addpath = os.path.join('/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-1]), 'submodule/Sana')
sys.path.append(addpath)

from torch import Tensor
from app.sana_pipeline import SanaPipeline, classify_height_width_bin, guidance_type_select
from diffusion.data.datasets.utils import (
    ASPECT_RATIO_512_TEST,
    ASPECT_RATIO_1024_TEST,
    ASPECT_RATIO_2048_TEST,
    ASPECT_RATIO_4096_TEST,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_encode, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger
from diffusion.model import gaussian_diffusion as gd
from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper


class CustomDPM_Solver(DPM_Solver):
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.0,
        dynamic_thresholding_ratio=0.995,
    ):
        super().__init__(
            model_fn,
            noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
            correcting_xt_fn=correcting_xt_fn,
            thresholding_max_val=thresholding_max_val,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
        )

    def inverse(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
        flow_shift=1.0,
    ):
        """
        Invert the sample `x` from time `t_start` to `t_end` with DPM-Solver.
        For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps during training.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
        t_T = self.noise_schedule.T if t_end is None else t_end
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        # Reuse `sample` with swapped endpoints: integrate from t_0 (data) up to t_T (noise).
        return self.sample(
            x,
            steps=steps,
            t_start=t_0,
            t_end=t_T,
            order=order,
            skip_type=skip_type,
            method=method,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero,
            solver_type=solver_type,
            atol=atol,
            rtol=rtol,
            return_intermediate=return_intermediate,
            flow_shift=flow_shift,
        )


def DPMS(
    model,
    condition,
    uncondition,
    cfg_scale,
    pag_scale=1.0,
    pag_applied_layers=None,
    model_type="noise",  # one of "noise", "x_start", "v", "score", "flow"
    noise_schedule="linear",
    guidance_type="classifier-free",
    model_kwargs=None,
    diffusion_steps=1000,
    schedule="VP",
    interval_guidance=None,
):
    if pag_applied_layers is None:
        pag_applied_layers = []
    if model_kwargs is None:
        model_kwargs = {}
    if interval_guidance is None:
        interval_guidance = [0, 1.0]
    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
    ## 1. Define the noise schedule.
    if schedule == "VP":
        noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
    elif schedule == "FLOW":
        noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
    ## 2. Convert your discrete-time `model` to the continuous-time
    ## noise prediction model. Here is an example for a diffusion model
    ## `model` with the noise prediction type ("noise").
    model_fn = model_wrapper(
        model,
        noise_schedule,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        pag_scale=pag_scale,
        pag_applied_layers=pag_applied_layers,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
        interval_guidance=interval_guidance,
    )
    ## 3. Define dpm-solver and sample by multistep DPM-Solver.
    return CustomDPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")


class DPMInversePipeline(SanaPipeline):
    def __init__(self, config_path):
        super().__init__(config_path)

    @torch.inference_mode()
    def prepare_prompt(self, prompts):
        if not self.config.text_encoder.chi_prompt:
            max_length_all = self.config.text_encoder.model_max_length
            prompts_all = prompts
        else:
            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
            prompts_all = [chi_prompt + prompt for prompt in prompts]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            max_length_all = (
                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
            )  # magic number 2: [bos], [_]
        caption_token = self.tokenizer(
            prompts_all,
            max_length=max_length_all,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device=self.device)
        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
            :, :, select_index
        ].to(self.weight_dtype)
        emb_masks = caption_token.attention_mask[:, select_index]
        return caption_embs, emb_masks

    @torch.inference_mode()
    def prepare_scheduler(self, caption_embs, null_y, guidance_scale, pag_guidance_scale, hw, ar, emb_masks):
        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
        if self.vis_sampler == "flow_euler":
            raise NotImplementedError("Flow Euler is not supported for editing.")
        elif self.vis_sampler == "flow_dpm-solver":
            scheduler = DPMS(
                self.model,
                condition=caption_embs,
                uncondition=null_y,
                guidance_type=self.guidance_type,
                cfg_scale=guidance_scale,
                pag_scale=pag_guidance_scale,
                pag_applied_layers=self.config.model.pag_applied_layers,
                model_type="flow",
                model_kwargs=model_kwargs,
                schedule="FLOW",
            )
            scheduler.register_progress_bar(self.progress_fn)
            return scheduler
        else:
            raise ValueError(f"Unsupported sampler: {self.vis_sampler}")

    @torch.inference_mode()
    def edit(
        self,
        src_prompt: list | str = None,
        tgt_prompt: list | str = None,
        src_img: list[Tensor] = None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inversion_steps=5,
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        generator=torch.Generator().manual_seed(42),
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
        if src_prompt is None or tgt_prompt is None or src_img is None:
            raise ValueError("src_prompt, tgt_prompt and src_img must be provided.")
        src_prompts = src_prompt if isinstance(src_prompt, list) else [src_prompt]
        tgt_prompts = tgt_prompt if isinstance(tgt_prompt, list) else [tgt_prompt]
        src_imgs = src_img if isinstance(src_img, list) else [src_img]
        samples = []
        for sprompt, tprompt, imgs in zip(src_prompts, tgt_prompts, src_imgs):
            # data prepare
            num_images_per_prompt = imgs.size(0)
            sprompts, tprompts, hw, ar = (
                [], [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(num_images_per_prompt, 1),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )
            for _ in range(num_images_per_prompt):
                sprompts.append(prepare_prompt_ar(sprompt, self.base_ratios, device=self.device, show=False)[0].strip())
                tprompts.append(prepare_prompt_ar(tprompt, self.base_ratios, device=self.device, show=False)[0].strip())
            with torch.no_grad():
                # prepare text feature
                src_caption_embs, src_emb_masks = self.prepare_prompt(sprompts)
                tgt_caption_embs, tgt_emb_masks = self.prepare_prompt(tprompts)
                null_y = self.null_caption_embs.repeat(len(sprompts), 1, 1)[:, None].to(self.weight_dtype)
                # inversion step: encode the source image to a latent, then run the solver towards noise
                scheduler = self.prepare_scheduler(src_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=src_emb_masks)
                latent = vae_encode(self.config.vae.vae_type, self.vae, imgs, False, self.device)
                noisy_latent = scheduler.inverse(
                    x=latent,
                    steps=num_inversion_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
                print(noisy_latent.max(), noisy_latent.min(), noisy_latent.mean(), noisy_latent.shape)
                # sampling: denoise the inverted latent with the target prompt
                scheduler = self.prepare_scheduler(tgt_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=tgt_emb_masks)
                sample = scheduler.sample(
                    noisy_latent,
                    steps=num_inference_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)
        return samples

    @torch.inference_mode()
    def forward(
        self,
        prompt=None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        num_images_per_prompt=1,
        generator=torch.Generator().manual_seed(42),
        latents=None,
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
        if prompt is None:
            prompt = [""]
        prompts = prompt if isinstance(prompt, list) else [prompt]
        samples = []
        for prompt in prompts:
            # data prepare
            prompts, hw, ar = (
                [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
                    num_images_per_prompt, 1
                ),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )
            for _ in range(num_images_per_prompt):
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
            with torch.no_grad():
                # prepare text feature (same tokenization/selection logic as prepare_prompt above)
                caption_embs, emb_masks = self.prepare_prompt(prompts)
                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
                n = len(prompts)
                if latents is None:
                    z = torch.randn(
                        n,
                        self.config.vae.vae_latent_dim,
                        self.latent_size_h,
                        self.latent_size_w,
                        generator=generator,
                        device=self.device,
                    )
                else:
                    z = latents.to(self.device)
                scheduler = self.prepare_scheduler(caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=emb_masks)
                sample = scheduler.sample(
                    z,
                    steps=num_inference_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)
        return samples


if __name__ == '__main__':
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser(description="Generate images using DPMInversePipeline.")
    parser.add_argument("--src_prompt", type=str, default="a yellow cat, frontal view, eye-level elevation, no tilt.",
                        help="Source text prompt for image generation.")
    parser.add_argument("--tgt_prompt", type=str, default="a yellow cat, side view, eye-level elevation, no tilt.",
                        help="Target text prompt for image editing.")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative text prompt for image generation.")
    parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
    parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale for the pipeline.")
    parser.add_argument("--pag_guidance_scale", type=float, default=1.0, help="PAG guidance scale for the pipeline.")
    parser.add_argument("--num_inference_steps", type=int, default=20, help="Number of inference steps.")
    parser.add_argument("--num_images_per_prompt", type=int, default=2, help="Number of images to generate per prompt.")
    parser.add_argument("--num_inversion_steps", type=int, default=5, help="Number of inversion steps for image editing.")
    parser.add_argument("--config_path", type=str,
                        default="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml",
                        help="Path to the model configuration file.")
    parser.add_argument("--from_pretrained", type=str,
                        default="hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
                        help="Path to the pretrained model weights.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
    args = parser.parse_args()

    # Replace spaces with underscores in the source prompt
    sanitized_prompt = args.src_prompt.replace(" ", "_")
    # Generate a unique folder name by hashing the settings as a JSON string
    settings = {
        "src_prompt": args.src_prompt,
        "tgt_prompt": args.tgt_prompt,
        "negative_prompt": args.negative_prompt,
        "config_path": args.config_path,
        "from_pretrained": args.from_pretrained,
        "height": args.height,
        "width": args.width,
        "guidance_scale": args.guidance_scale,
        "pag_guidance_scale": args.pag_guidance_scale,
        "num_inference_steps": args.num_inference_steps,
        "num_images_per_prompt": args.num_images_per_prompt,
        "num_inversion_steps": args.num_inversion_steps,
        "seed": args.seed,
    }
    settings_str = json.dumps(settings, sort_keys=True)
    settings_hash = hashlib.md5(settings_str.encode()).hexdigest()
    # Create output directory using the settings hash as the folder name
    output_dir = os.path.join("editinv", sanitized_prompt, settings_hash)
    os.makedirs(output_dir, exist_ok=True)
    # Output file paths
    generated_file = os.path.join(output_dir, "sample.png")
    edited_file = os.path.join(output_dir, "edited.png")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=device).manual_seed(args.seed)
    config_path = os.path.join(addpath, args.config_path)
    sana = DPMInversePipeline(config_path)
    sana.from_pretrained(args.from_pretrained)

    # Generate images
    images = sana(
        prompt=args.src_prompt,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        generator=generator,
        num_images_per_prompt=args.num_images_per_prompt,
    )
    print(f"Generated image shape: {images[0].shape}")
    save_image(images[0], generated_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Image saved to {generated_file}")

    # Edit images
    edited_images = sana.edit(
        src_prompt=args.src_prompt,
        tgt_prompt=args.tgt_prompt,
        src_img=images,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        num_inversion_steps=args.num_inversion_steps,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        generator=generator,
    )
    print(f"Edited image shape: {edited_images[0].shape}")
    save_image(edited_images[0], edited_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Edited image saved to {edited_file}")
```
The changes are:
- I customized the Sana pipeline to add an `edit` function for editing images.
- I customized the `inverse` function of DPM-Solver to take `flow_shift` as an argument. It's worth noting that with either `flow_shift=1` or `flow_shift=3`, the noisy latents I get are the same (all NaN).
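One hypothesis I have not verified yet (please correct me if this is wrong): `NoiseScheduleFlow` appears to use `alpha_t = 1 - t`, so `lambda_t = log(alpha_t / sigma_t)` diverges as `t -> 1`, and inverting all the way to the default `t_end = T = 1` would push `log(0)` through the solver. If that is the cause, stopping slightly below 1 should avoid the NaNs:

```python
# Sketch of a possible workaround (an assumption, not a confirmed fix): keep the
# inversion endpoint strictly below T = 1 so lambda_t = log((1 - t) / t) stays finite.
noisy_latent = scheduler.inverse(
    x=latent,
    steps=num_inversion_steps,
    t_end=1.0 - 1e-3,  # instead of the default t_end = self.noise_schedule.T = 1
    order=2,
    skip_type="time_uniform_flow",
    method="multistep",
    flow_shift=self.flow_shift,
)
```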
Thank you!
The sampling process in SANA is customized to facilitate flow-based sampling using DPM-Solver++ (which DiffEdit also uses to achieve image inversion for editing). I would suggest experimenting with the encoding mechanism used in the DiffEdit implementation. You can also play around with the encoding ratio to balance image reconstruction against generation.
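For illustration, here is a minimal sketch of what I mean by the encoding ratio (my own phrasing built on the `inverse`/`sample` signatures in your snippet, not the exact DiffEdit code): invert the source latent only up to `t = encoding_ratio` instead of `t = T`, then run the reverse process from that same intermediate time with the edit prompt.

```python
# `src_scheduler` / `tgt_scheduler` are hypothetical names for the DPMS(...)
# solvers built with the source and edit prompts, as in your edit() method.
encoding_ratio = 0.6  # 0 -> pure reconstruction, 1 -> fully re-noised

noisy_latent = src_scheduler.inverse(
    x=latent,
    steps=num_inversion_steps,
    t_end=encoding_ratio,  # stop the forward ODE early
    skip_type="time_uniform_flow",
    method="multistep",
)
edited_latent = tgt_scheduler.sample(
    noisy_latent,
    steps=num_inference_steps,
    t_start=encoding_ratio,  # resume the reverse ODE from the same time
    skip_type="time_uniform_flow",
    method="multistep",
)
```

A lower ratio keeps more of the source image; a higher ratio gives the edit prompt more freedom.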
I tested with a simple src/edit prompt pair (with encoding ratio = 0.6) and here is the result. Source prompt: "A basket of apples". Edit prompt: "A basket of oranges". Left to right: source image, reconstructed image with the source prompt, edited image with the edit prompt.
@KhoiDOO Please try @nttung1110's suggestion above if it sounds correct.
Hi, the above inversion technique can also be applied to the image in-painting task. Here are a few results tested on SANA-multistep with 40 inference steps (with the in-painting mask overlaid on the source image). Hope it helps.
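For anyone curious about the mechanics, here is a minimal sketch of the usual latent-blending trick I am referring to (an illustration of the general idea, not necessarily the exact code behind the results above). The `DPM_Solver` constructor in the snippet earlier already exposes a `correcting_xt_fn` hook that runs after every step, which is a natural place to re-impose the known region of the source latent:

```python
import torch

def make_inpaint_correcting_fn(src_latent, mask, noise):
    # mask is 1 where the image should be repainted, 0 where the source is kept.
    # Assumes the rectified-flow interpolation x_t = (1 - t) * x0 + t * eps.
    def correcting_xt_fn(xt, t, step):
        noised_src = (1.0 - t) * src_latent + t * noise  # source latent re-noised to time t
        return mask * xt + (1.0 - mask) * noised_src
    return correcting_xt_fn
```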
Very cool examples. Is it possible to add this inversion feature into our code base? @nttung1110
That's interesting, @nttung1110, about the inpainting results. I just wonder why the editing-via-inversion results are not good; the edited fruits are not oranges. There might be a gap, since DiffEdit uses DPMSolverSampler. I haven't read it, so I currently do not know the differences between the plus-plus and non-plus versions.
Sure, please let me know how to add these features into your code base. Should I post the code snippet here?
In short, DPM-Solver and its extended version (DPM-Solver++) are both high-order solvers for fast sampling, compared to DDIM (a first-order solver with a slower sampling process). The difference lies in the guided sampling process when applying a large guidance scale: the authors argue that DPM-Solver is not suitable or effective in that regime, so they designed the extended version to enable large guidance scales while keeping sampling fast. Hope this helps.
And in the case of DiffEdit, I guess there is little difference between DPM-Solver and its plus-plus version unless you want to set a large guidance scale.
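For concreteness, the first-order updates (as I recall them from the two papers; treat the notation as approximate) differ only in which network output they integrate. DPM-Solver uses the noise prediction $\epsilon_\theta$, while DPM-Solver++ uses the data prediction $x_\theta$, which stays bounded under thresholding and is what makes large guidance scales workable:

$$
x_{t_i} = \frac{\alpha_{t_i}}{\alpha_{t_{i-1}}}\, x_{t_{i-1}} - \sigma_{t_i}\left(e^{h_i}-1\right)\epsilon_\theta\left(x_{t_{i-1}}, t_{i-1}\right) \quad \text{(DPM-Solver-1)}
$$

$$
x_{t_i} = \frac{\sigma_{t_i}}{\sigma_{t_{i-1}}}\, x_{t_{i-1}} - \alpha_{t_i}\left(e^{-h_i}-1\right) x_\theta\left(x_{t_{i-1}}, t_{i-1}\right) \quad \text{(DPM-Solver++(1))}
$$

where $h_i = \lambda_{t_i} - \lambda_{t_{i-1}}$ and $\lambda_t = \log(\alpha_t / \sigma_t)$.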
I'm very interested in testing your code; I'm working on img2img functionality with poor results.
Same! Dear @nttung1110, would you mind sharing your implementation? :o
Hi @ChunChenLin,
Sure, I would love to share mine. I'm waiting for a reply from the repo author @lawrence-cj about adding the inpainting feature to the repo.
Oh, could you please add a branch and open a PR for your inpainting feature code? @nttung1110
Hi @lawrence-cj, thanks for notifying me about that. I'll try my best to push a PR when I have spare time. Thanks!
Hi @lawrence-cj, sorry for the delay. I didn't have much time to refactor the code and make a pull request as an integrated feature in your main repo, so I uploaded an unofficial implementation of image inpainting on both SANA and SANA-Sprint at https://github.com/nttung1110/SANA-Inpainting/tree/main. I will refactor the code and make a PR later. Hope it helps.
Hi @nttung1110, the link is a 404; could you post an updated one?
Hi @lawrence-cj, thanks for the great work on SANA! I created a PR for this for standard SANA (not the Sprint model) here: https://github.com/NVlabs/Sana/pull/296. Happy to make changes to it as well if that is useful.