Image Editing via Inversion

Open KhoiDOO opened this issue 7 months ago • 9 comments

Hi, please correct me if I'm wrong. I tried using the inverse function in DPM-Solver to invert the source latent into a noisy latent. After obtaining the noisy latent, I use the sample function to get the edited image. However, the noisy latent returned by the inverse function consists entirely of NaN values. The code is below; please take a look.

import argparse
import torch
import sys
import os
import hashlib
import json
addpath = os.path.join('/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-1]), 'submodule/Sana')
sys.path.append(addpath)

from torch import Tensor

from app.sana_pipeline import SanaPipeline, classify_height_width_bin, guidance_type_select
from diffusion.data.datasets.utils import (
    ASPECT_RATIO_512_TEST,
    ASPECT_RATIO_1024_TEST,
    ASPECT_RATIO_2048_TEST,
    ASPECT_RATIO_4096_TEST,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_encode, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

from diffusion.model import gaussian_diffusion as gd
from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper


class CustomDPM_Solver(DPM_Solver):
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.0,
        dynamic_thresholding_ratio=0.995,
    ):
        super().__init__(
            model_fn,
            noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
            correcting_xt_fn=correcting_xt_fn,
            thresholding_max_val=thresholding_max_val,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
        )
    
    def inverse(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
        flow_shift=1.0,
    ):
        """
        Invert the sample `x` from time `t_start` to `t_end` with DPM-Solver.
        For discrete-time DPMs, `t_start` defaults to `1/N`, where `N` is the total number of training time steps.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
        t_T = self.noise_schedule.T if t_end is None else t_end
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        return self.sample(
            x,
            steps=steps,
            t_start=t_0,
            t_end=t_T,
            order=order,
            skip_type=skip_type,
            method=method,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero,
            solver_type=solver_type,
            atol=atol,
            rtol=rtol,
            return_intermediate=return_intermediate,
            flow_shift=flow_shift,
        )

def DPMS(
    model,
    condition,
    uncondition,
    cfg_scale,
    pag_scale=1.0,
    pag_applied_layers=None,
    model_type="noise",  # one of "noise", "x_start", "v", "score", or "flow"
    noise_schedule="linear",
    guidance_type="classifier-free",
    model_kwargs=None,
    diffusion_steps=1000,
    schedule="VP",
    interval_guidance=None,
):
    if pag_applied_layers is None:
        pag_applied_layers = []
    if model_kwargs is None:
        model_kwargs = {}
    if interval_guidance is None:
        interval_guidance = [0, 1.0]
    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))

    ## 1. Define the noise schedule.
    if schedule == "VP":
        noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
    elif schedule == "FLOW":
        noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")

    ## 2. Convert your discrete-time `model` to the continuous-time
    ## noise prediction model. Here is an example for a diffusion model
    ## `model` with the noise prediction type ("noise") .
    model_fn = model_wrapper(
        model,
        noise_schedule,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        pag_scale=pag_scale,
        pag_applied_layers=pag_applied_layers,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
        interval_guidance=interval_guidance,
    )
    ## 3. Build the DPM-Solver++ instance; the caller runs inversion or sampling on it.
    return CustomDPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")


class DPMInversePipeline(SanaPipeline):
    def __init__(self, config_path):
        super().__init__(config_path)
    
    @torch.inference_mode()
    def prepare_prompt(self, prompts):
        if not self.config.text_encoder.chi_prompt:
            max_length_all = self.config.text_encoder.model_max_length
            prompts_all = prompts
        else:
            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
            prompts_all = [chi_prompt + prompt for prompt in prompts]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            max_length_all = (
                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
            )  # magic number 2: [bos], [_]

        caption_token = self.tokenizer(
            prompts_all,
            max_length=max_length_all,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device=self.device)
        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
            :, :, select_index
        ].to(self.weight_dtype)
        emb_masks = caption_token.attention_mask[:, select_index]

        return caption_embs, emb_masks
    
    @torch.inference_mode()
    def prepare_scheduler(self, caption_embs, null_y, guidance_scale, pag_guidance_scale, hw, ar, emb_masks):
        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
        if self.vis_sampler == "flow_euler":
            raise NotImplementedError("Flow Euler is not supported for editing.")
        elif self.vis_sampler == "flow_dpm-solver":
            scheduler = DPMS(
                self.model,
                condition=caption_embs,
                uncondition=null_y,
                guidance_type=self.guidance_type,
                cfg_scale=guidance_scale,
                pag_scale=pag_guidance_scale,
                pag_applied_layers=self.config.model.pag_applied_layers,
                model_type="flow",
                model_kwargs=model_kwargs,
                schedule="FLOW",
            )
            scheduler.register_progress_bar(self.progress_fn)
            return scheduler
        else:
            raise ValueError(f"Unsupported sampler: {self.vis_sampler}")
        
    @torch.inference_mode()
    def edit(
        self,
        src_prompt: list | str = None,
        tgt_prompt: list | str = None,
        src_img: list[Tensor] = None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inversion_steps=5,
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        generator=torch.Generator().manual_seed(42),
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]

        if src_prompt is None or tgt_prompt is None or src_img is None:
            raise ValueError("src_prompt, tgt_prompt and src_img must be provided.")
        src_prompts = src_prompt if isinstance(src_prompt, list) else [src_prompt]
        tgt_prompts = tgt_prompt if isinstance(tgt_prompt, list) else [tgt_prompt]
        src_imgs = src_img if isinstance(src_img, list) else [src_img]
        samples = []

        for sprompt, tprompt, imgs in zip(src_prompts, tgt_prompts, src_imgs):
            # data prepare
            num_images_per_prompt = imgs.size(0)
            sprompts, tprompts, hw, ar = (
                [], [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(num_images_per_prompt, 1),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            for _ in range(num_images_per_prompt):
                sprompts.append(prepare_prompt_ar(sprompt, self.base_ratios, device=self.device, show=False)[0].strip())
                tprompts.append(prepare_prompt_ar(tprompt, self.base_ratios, device=self.device, show=False)[0].strip())

            with torch.no_grad():
                # prepare text feature
                src_caption_embs, src_emb_masks = self.prepare_prompt(sprompts)
                tgt_caption_embs, tgt_emb_masks = self.prepare_prompt(tprompts)

                null_y = self.null_caption_embs.repeat(len(sprompts), 1, 1)[:, None].to(self.weight_dtype)

                # inversion step
                scheduler = self.prepare_scheduler(src_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=src_emb_masks)
                latent = vae_encode(self.config.vae.vae_type, self.vae, imgs, False, self.device)
                noisy_latent = scheduler.inverse(
                    x = latent, 
                    steps=num_inversion_steps, 
                    order=2, 
                    skip_type="time_uniform_flow", 
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
                print(noisy_latent.max(), noisy_latent.min(), noisy_latent.mean(), noisy_latent.shape)
                
                # sampling
                scheduler = self.prepare_scheduler(tgt_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=tgt_emb_masks)
                sample = scheduler.sample(
                    noisy_latent, 
                    steps=num_inference_steps, 
                    order=2, 
                    skip_type="time_uniform_flow", 
                    method="multistep", 
                    flow_shift=self.flow_shift
                )
                    
            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

        return samples
    
    @torch.inference_mode()
    def forward(
        self,
        prompt=None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        num_images_per_prompt=1,
        generator=torch.Generator().manual_seed(42),
        latents=None,
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

        if prompt is None:
            prompt = [""]
        prompts = prompt if isinstance(prompt, list) else [prompt]
        samples = []

        for prompt in prompts:
            # data prepare
            prompts, hw, ar = (
                [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
                    num_images_per_prompt, 1
                ),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            for _ in range(num_images_per_prompt):
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())

            with torch.no_grad():
                # prepare text feature
                if not self.config.text_encoder.chi_prompt:
                    max_length_all = self.config.text_encoder.model_max_length
                    prompts_all = prompts
                else:
                    chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
                    prompts_all = [chi_prompt + prompt for prompt in prompts]
                    num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
                    max_length_all = (
                        num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
                    )  # magic number 2: [bos], [_]

                caption_token = self.tokenizer(
                    prompts_all,
                    max_length=max_length_all,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(device=self.device)
                select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
                caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
                    :, :, select_index
                ].to(self.weight_dtype)
                emb_masks = caption_token.attention_mask[:, select_index]
                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)

                n = len(prompts)
                if latents is None:
                    z = torch.randn(
                        n,
                        self.config.vae.vae_latent_dim,
                        self.latent_size_h,
                        self.latent_size_w,
                        generator=generator,
                        device=self.device,
                    )
                else:
                    z = latents.to(self.device)
                scheduler = self.prepare_scheduler(caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=emb_masks)
                sample = scheduler.sample(
                    z,
                    steps=num_inference_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )   

            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

        return samples
        

if __name__ == '__main__':

    from torchvision.utils import save_image

    parser = argparse.ArgumentParser(description="Generate images using DPMInversePipeline.")
    parser.add_argument("--src_prompt", type=str, default="a yellow cat, frontal view, eye-level elevation, no tilt.", 
                        help="Source text prompt for image generation.")
    parser.add_argument("--tgt_prompt", type=str, default="a yellow cat, side view, eye-level elevation, no tilt.", 
                        help="Target text prompt for image editing.")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative text prompt for image generation.")
    parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
    parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale for the pipeline.")
    parser.add_argument("--pag_guidance_scale", type=float, default=1.0, help="PAG guidance scale for the pipeline.")
    parser.add_argument("--num_inference_steps", type=int, default=20, help="Number of inference steps.")
    parser.add_argument("--num_images_per_prompt", type=int, default=2, help="Number of images to generate per prompt.")
    parser.add_argument("--num_inversion_steps", type=int, default=5, help="Number of inversion steps for image editing.")
    parser.add_argument("--config_path", type=str, 
                        default="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml", 
                        help="Path to the model configuration file.")
    parser.add_argument("--from_pretrained", type=str, 
                        default="hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth", 
                        help="Path to the pretrained model weights.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")  # Added seed argument

    args = parser.parse_args()

    # Replace spaces with underscores in the source prompt
    sanitized_prompt = args.src_prompt.replace(" ", "_")

    # Generate a unique folder name based on settings as a JSON string
    settings = {
        "src_prompt": args.src_prompt,
        "tgt_prompt": args.tgt_prompt,
        "negative_prompt": args.negative_prompt,
        "config_path": args.config_path,
        "from_pretrained": args.from_pretrained,
        "height": args.height,
        "width": args.width,
        "guidance_scale": args.guidance_scale,
        "pag_guidance_scale": args.pag_guidance_scale,
        "num_inference_steps": args.num_inference_steps,
        "num_images_per_prompt": args.num_images_per_prompt,
        "num_inversion_steps": args.num_inversion_steps,
        "seed": args.seed  # Added seed to settings
    }
    settings_str = json.dumps(settings, sort_keys=True)

    # Encode settings_str as a hash code
    settings_hash = hashlib.md5(settings_str.encode()).hexdigest()

    # Create the output directory, using the settings hash as the folder name
    output_dir = os.path.join("editinv", sanitized_prompt, settings_hash)
    os.makedirs(output_dir, exist_ok=True)

    # Output file paths
    generated_file = os.path.join(output_dir, "sample.png")
    edited_file = os.path.join(output_dir, "edited.png")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=device).manual_seed(args.seed)  # Use user-configured seed

    config_path = os.path.join(addpath, args.config_path)

    sana = DPMInversePipeline(config_path)
    sana.from_pretrained(args.from_pretrained)

    # Generate images
    images = sana(
        prompt=args.src_prompt,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        generator=generator,
        num_images_per_prompt=args.num_images_per_prompt
    )

    print(f"Generated image shape: {images[0].shape}")
    save_image(images[0], generated_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Image saved to {generated_file}")

    # Edit images
    edited_images = sana.edit(
        src_prompt=args.src_prompt,
        tgt_prompt=args.tgt_prompt,
        src_img=images,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        num_inversion_steps=args.num_inversion_steps,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        generator=generator,
    )

    print(f"Edited image shape: {edited_images[0].shape}")
    save_image(edited_images[0], edited_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Edited image saved to {edited_file}")

The changes are:

  • I customized the Sana pipeline to add a function for editing the image.
  • I customized the inverse function of DPM-Solver to accept flow_shift as an argument. It's worth noting that with either flow_shift=1 or flow_shift=3 I get the same result: the noisy latent is all NaN.
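One possible explanation for the NaN values, offered as an assumption rather than a confirmed diagnosis: the `inverse` function above integrates up to `t_end = noise_schedule.T`. If the flow schedule uses `alpha_t = 1 - t` and `sigma_t = t` with `T = 1` (assumed here, not checked against the Sana source), then `alpha_T = 0` and the half-log-SNR is `-inf` at the end point. In the first-order DPM-Solver++ update, `x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred`, this gives `0 * inf`, which is NaN. Sampling starts at `T` instead of ending there, so the same formula stays finite. The helper names below are hypothetical.

```python
import math

def flow_schedule(t):
    # Assumption: alpha_t = 1 - t (signal weight), sigma_t = t (noise weight).
    return 1.0 - t, t

def half_log_snr(t):
    # lambda_t = log(alpha_t) - log(sigma_t); tends to -inf as t -> 1.
    alpha, sigma = flow_schedule(t)
    log_alpha = math.log(alpha) if alpha > 0.0 else -math.inf
    return log_alpha - math.log(sigma)

def data_coeff(s, t):
    # Coefficient on the predicted clean sample in the first-order
    # DPM-Solver++ update from time s to time t: -alpha_t * (exp(-h) - 1).
    alpha_t, _ = flow_schedule(t)
    h = half_log_snr(t) - half_log_snr(s)
    return -alpha_t * (math.exp(-h) - 1.0)

print(data_coeff(1.0, 0.5))    # sampling direction (noise -> data): finite
print(data_coeff(0.5, 1.0))    # inversion ending exactly at t = 1: nan
print(data_coeff(0.5, 0.999))  # inversion stopped just short of t = 1: finite
```

If this is what happens, passing a `t_end` slightly below `T` to `inverse`, or inverting only part of the way, might avoid the NaN. This has not been tested against the pipeline in this issue.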

Thank you!

KhoiDOO avatar Apr 10 '25 11:04 KhoiDOO

The sampling process in SANA is customized to support flow-based sampling with DPM-Solver++ (which DiffEdit also uses to invert images for editing). I would suggest experimenting with the encoding mechanism used in the DiffEdit implementation. You can also adjust the encoding ratio to find the balance between image reconstruction and generation.

I tested a simple source/edit prompt pair (with encoding ratio = 0.6), and here is the result. Source prompt: "A basket of apples". Edit prompt: "A basket of oranges". Left to right: source image, image reconstructed with the source prompt, image edited with the edit prompt.

Image
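As a rough illustration of the encoding ratio, here is a minimal sketch of one way to read it: the image is inverted only up to an intermediate time `t_enc` rather than all the way to pure noise, and sampling then starts from `t_enc`. The helper `partial_time_grids`, its parameters, and the uniform spacing are hypothetical and are not taken from the DiffEdit or SANA code.

```python
def partial_time_grids(encoding_ratio, num_inversion_steps, num_inference_steps,
                       t_min=1e-3, t_max=1.0):
    """Return (inversion_times, sampling_times) for partial inversion.

    Hypothetical helper: a ratio of 1.0 inverts all the way to t_max, while a
    smaller ratio stops earlier and keeps more of the source image.
    """
    if not 0.0 < encoding_ratio <= 1.0:
        raise ValueError("encoding_ratio must be in (0, 1]")
    t_enc = t_min + encoding_ratio * (t_max - t_min)

    def linspace(a, b, n):
        # n intervals, n + 1 points, both end points included
        return [a + (b - a) * i / n for i in range(n + 1)]

    inversion_times = linspace(t_min, t_enc, num_inversion_steps)  # clean -> noisy
    sampling_times = linspace(t_enc, t_min, num_inference_steps)   # noisy -> clean
    return inversion_times, sampling_times

inv, smp = partial_time_grids(0.6, num_inversion_steps=5, num_inference_steps=20)
print(inv[-1], smp[0])  # both near t_enc, the hand-over point
```

With the code in this issue, the corresponding change would be to pass `t_end=t_enc` to `inverse` and `t_start=t_enc` to `sample`. Whether that reproduces the results shown above is not confirmed here.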

nttung1110 avatar Apr 30 '25 09:04 nttung1110

@KhoiDOO Please try @nttung1110's suggestion above if it sounds right to you.

lawrence-cj avatar May 08 '25 05:05 lawrence-cj

Hi, the inversion technique above can also be applied to image inpainting. Here are a few results from SANA-multistep with 40 inference steps (the inpainting mask is overlaid on the source image). Hope it helps.

Image
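The implementation behind these results is not shown in the thread. A common way to do inpainting with a sampler is masked blending: after every denoising step, the region outside the mask is overwritten with the source latent noised to the current time. The sketch below illustrates that idea on plain Python lists. The flow-style interpolation `x_t = (1 - t) * x0 + t * noise`, the function names, and the toy denoiser are all assumptions made for illustration.

```python
import random

def noise_source(x0, noise, t):
    # Assumed flow-style forward interpolation: x_t = (1 - t) * x0 + t * noise.
    return [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]

def blend(x_gen, x_src, mask):
    # mask == 1: keep the generated value; mask == 0: keep the source value.
    return [m * g + (1.0 - m) * s for g, s, m in zip(x_gen, x_src, mask)]

def inpaint(x0_src, mask, denoise_step, times, seed=0):
    rng = random.Random(seed)
    noise = [rng.gauss(0.0, 1.0) for _ in x0_src]
    x = noise_source(x0_src, noise, times[0])
    for s, t in zip(times[:-1], times[1:]):
        x = denoise_step(x, s, t)                            # one solver step from s to t
        x = blend(x, noise_source(x0_src, noise, t), mask)   # re-impose the known region
    return x

# Toy stand-in for a solver step: pulls every value toward 0.5 as t -> 0.
def toy_step(x, s, t):
    return [0.5 + (v - 0.5) * (t / s) for v in x]

source = [0.1, 0.2, 0.3, 0.4]
mask = [0.0, 0.0, 1.0, 1.0]  # regenerate only the last two entries
print(inpaint(source, mask, toy_step, [1.0, 0.75, 0.5, 0.25, 0.0]))
```

In a real pipeline `denoise_step` would be a solver update on VAE latents, and the mask would be resized to the latent resolution.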

nttung1110 avatar May 12 '25 13:05 nttung1110

Very cool examples. Is it possible to add this inversion feature to our code base? @nttung1110

lawrence-cj avatar May 12 '25 15:05 lawrence-cj

That's interesting, @nttung1110, about the inpainting results. I just wonder why the results of editing via inversion are not good; the edited fruits are not oranges. There might be a gap since DiffEdit uses DPMSolverSampler. I haven't read it, so I don't currently know the differences between the plus-plus and non-plus versions.

KhoiDOO avatar May 12 '25 16:05 KhoiDOO

Sure, please let me know how to add these features to your code base. Should I post the code snippet here?

nttung1110 avatar May 13 '25 02:05 nttung1110

In short, DPM-Solver and its extended version (DPM-Solver++) are both high-order solvers for fast sampling, in contrast to DDIM (a first-order solver that needs more steps). The difference lies in guided sampling with a large guidance scale. The authors state that DPM-Solver is not effective in that case, so they designed an extended version that supports large guidance scales in fast sampling. Hope this helps.
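To make the comparison concrete, here is a small scalar sketch of the first-order updates in the two parameterizations: DPM-Solver integrates a noise prediction, DPM-Solver++ a data (clean-sample) prediction. When the two predictions are consistent, `x0 = (x_s - sigma_s * eps) / alpha_s`, both first-order updates reduce to the same DDIM step; they start to differ at higher order, or when the data prediction is thresholded, which is only possible in the data parameterization. The function names and the numeric values are illustrative only.

```python
import math

def half_log_snr(alpha, sigma):
    return math.log(alpha) - math.log(sigma)

def step_noise_pred(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    # First-order DPM-Solver update (noise prediction).
    h = half_log_snr(alpha_t, sigma_t) - half_log_snr(alpha_s, sigma_s)
    return alpha_t / alpha_s * x_s - sigma_t * math.expm1(h) * eps

def step_data_pred(x_s, x0, alpha_s, sigma_s, alpha_t, sigma_t):
    # First-order DPM-Solver++ update (data prediction).
    h = half_log_snr(alpha_t, sigma_t) - half_log_snr(alpha_s, sigma_s)
    return sigma_t / sigma_s * x_s - alpha_t * math.expm1(-h) * x0

# Illustrative variance-preserving values (alpha^2 + sigma^2 = 1).
alpha_s, sigma_s = 0.6, 0.8
alpha_t, sigma_t = 0.8, 0.6
x_s, eps = 1.3, 0.4
x0 = (x_s - sigma_s * eps) / alpha_s  # data prediction implied by eps

a = step_noise_pred(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t)
b = step_data_pred(x_s, x0, alpha_s, sigma_s, alpha_t, sigma_t)
print(a, b)  # agree up to floating-point rounding

# With the data prediction clipped to [-1, 1], the ++ step changes.
x0_clipped = max(-1.0, min(1.0, x0))
print(step_data_pred(x_s, x0_clipped, alpha_s, sigma_s, alpha_t, sigma_t))
```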

nttung1110 avatar May 13 '25 04:05 nttung1110

And in the case of DiffEdit, I guess there is little difference between DPM-Solver and its plus-plus version unless you want to set a large guidance scale.

nttung1110 avatar May 13 '25 04:05 nttung1110

I'm very interested in testing your code. I'm working on an img2img feature, with poor results so far.

luca-saggese avatar May 17 '25 22:05 luca-saggese

Same! Dear @nttung1110, would you mind sharing your implementation? :o

ChunChenLin avatar Jun 04 '25 02:06 ChunChenLin

Hi @ChunChenLin,

Sure, I would love to share mine. I'm waiting for a reply from the repo author, @lawrence-cj, about adding the inpainting feature to the repo.

nttung1110 avatar Jul 01 '25 07:07 nttung1110

Oh, could you please create a branch and open a PR for your inpainting code? @nttung1110

lawrence-cj avatar Jul 01 '25 07:07 lawrence-cj

Hi @lawrence-cj, thanks for letting me know. I will try my best to open a PR when I have spare time. Thanks!

nttung1110 avatar Jul 10 '25 06:07 nttung1110

Hi @lawrence-cj, sorry for the delay. I haven't had much time to refactor the code and open a pull request that integrates the feature into your main repo, so I uploaded an unofficial implementation of image inpainting for both SANA and SANA-Sprint at https://github.com/nttung1110/SANA-Inpainting/tree/main. I will refactor the code and open a PR later. Hope it helps.

nttung1110 avatar Jul 18 '25 08:07 nttung1110

Hi @nttung1110, the link returns a 404. Could you post an updated one?

micklexqg avatar Aug 17 '25 04:08 micklexqg

Hi @lawrence-cj, thanks for the great work on SANA! I created a PR for this for standard SANA (not the Sprint model) here: https://github.com/NVlabs/Sana/pull/296. Happy to make changes to it if this is useful.

Cedric-Perauer avatar Aug 29 '25 16:08 Cedric-Perauer