
Inpainted images are degraded in quality

Open TheEnhas opened this issue 3 years ago • 7 comments

If I inpaint an image and then use that image as a source for more inpainting (mostly because inpainting rarely gets things right on the first try, i.e. it's better to inpaint and get some of what you want, save the image, and then use that as a base to inpaint again, repeating until you're done), the entire image ends up looking compressed somehow, and it gets worse with every iteration. I'm using .png mode too, by the way.

After some testing it looks like inpainted images are degraded in general for some reason, even after a single run (but it gets more noticeable if you keep using inpainted images as a source). No idea why.

TheEnhas avatar Aug 29 '22 23:08 TheEnhas

For me, it normalizes as you increase the ddim steps. It starts off too contrasty and blurry, but clears up.

GucciFlipFlops1917 avatar Aug 30 '22 22:08 GucciFlipFlops1917

I'm not just talking about the inpainted part, though; the entire image degrades for some reason no matter what I set the steps to. Maybe it's saving as jpg when it's not supposed to.

TheEnhas avatar Aug 30 '22 22:08 TheEnhas

Ah, gotcha. I think it may be the encoding-decoding process, with the degradation coming from converting the image to a latent-space representation and back to a viewable bitmap. I'm still learning, so that's mostly conjecture.
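One way to check would be to round-trip an image through the first-stage autoencoder a few times and watch the error grow. A rough sketch, assuming `modelFS` is the already-loaded first-stage model from the optimized scripts (run at full precision; the file name is made up):

```python
import numpy as np
import torch
from PIL import Image

# Round-trip a test image through the first-stage VAE several times and print
# how far it drifts from the original. modelFS is assumed to be loaded the same
# way the optimized scripts load it.
device = next(modelFS.parameters()).device

img = Image.open("test.png").convert("RGB").resize((512, 512), Image.LANCZOS)
ref = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
ref = ref.permute(2, 0, 1).unsqueeze(0).to(device) * 2.0 - 1.0  # NCHW, [-1, 1]

x = ref.clone()
with torch.no_grad():
    for i in range(4):
        z = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(x))
        x = modelFS.decode_first_stage(z).clamp(-1.0, 1.0)
        print(f"round trip {i + 1}: mean abs diff = {(x - ref).abs().mean().item():.4f}")
```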

GucciFlipFlops1917 avatar Aug 30 '22 22:08 GucciFlipFlops1917

I've tried inpainting on hlky as well and get the same problem, so it's not exclusive to the optimized fork's implementation.

Edit: Someone else is having the same issue, it's apparently because the entire image goes through the final sampler pass and not just the masked part: https://github.com/hlky/stable-diffusion/issues/153
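Which would explain the full-frame degradation: roughly, masked sampling blends the (re-noised) original latent back into the kept region at each step, but the final image is still one decode of the whole latent, so even the "untouched" pixels are regenerated from latent space rather than copied from the source. A runnable schematic with placeholder functions (not the repo's actual sampler code):

```python
import torch

# Placeholder stand-ins just so the schematic runs; the real versions are the
# DDIM step and the first-stage decoder.
denoise_step = lambda x, t, cond: 0.98 * x
add_noise = lambda x0, t: x0 + 0.1 * torch.randn_like(x0)
decode_first_stage = lambda z: z

init_latent = torch.randn(1, 4, 64, 64)                    # encoded source image
mask = torch.zeros_like(init_latent); mask[..., :32] = 1.0  # 1 = keep, 0 = inpaint

x = add_noise(init_latent, 0)
for t in reversed(range(50)):
    x = denoise_step(x, t, cond=None)                       # ordinary diffusion step
    x = mask * add_noise(init_latent, t) + (1 - mask) * x   # pin the kept region in latent space
result = decode_first_stage(x)  # one decode of the WHOLE latent: the "kept" pixels
                                # come out of the VAE too, not from the source image
```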

TheEnhas avatar Aug 31 '22 01:08 TheEnhas

Just to be clear, this only started happening with a recent change on hlky; it did not work like that before with their img2img-based "inpainting" implementation.

TingTingin avatar Aug 31 '22 15:08 TingTingin

You can paste the original image back onto your results for a quick and dirty fix. Here's the optimized version of the script. I don't use gradio, so you'll have to port it yourself. It just stores the image right after you load it, stores the mask at full resolution, and then, after your GPU compute, lerps the original image back onto your results based on the mask, right before you save everything to a file. This eliminated all loss for me. It may mess up your transitions if you use gradient (0,1) masks, though.
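The core of the fix, taken on its own, is just an alpha blend of the finished sample with the untouched source pixels, using the full-resolution mask as the alpha (minimal sketch with made-up file names; white in the mask = keep the original):

```python
import numpy as np
from PIL import Image

orig = np.array(Image.open("source.png").convert("RGB")).astype(np.float32)
sample = np.array(Image.open("inpainted.png").convert("RGB")).astype(np.float32)
mask = np.array(Image.open("mask.png").convert("RGB")).astype(np.float32) / 255.0

# Keep the generated pixels where the mask is black, restore the originals
# where it is white.
fixed = sample * (1.0 - mask) + orig * mask
Image.fromarray(fixed.astype(np.uint8)).save("fixed.png")
```

The full script: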

```python
import argparse, os, re
import torch
import numpy as np
from random import randint
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
from transformers import logging
import pandas as pd
logging.set_verbosity_error()


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    return sd


def load_img(path, h0, w0):

    image = Image.open(path).convert("RGB")
    w, h = image.size

    print(f"loaded input image of size ({w}, {h}) from {path}")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64

    print(f"New image size ({w}, {h})")
    image = image.resize((w, h), resample=Image.LANCZOS)
    raw_image = np.array(image)
    image = raw_image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0, raw_image

def load_mask(path, h0, w0, invert=False):
   
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")   
    if(h0 is not None and w0 is not None):
        h, w = h0, w0
    
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64

    print(f"New image size ({w}, {h})")
    raw_image = np.array(image).astype(np.float32) / 255.0
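    # NOTE: the latent-space mask is hard-coded to 64x64, which only matches a
    # 512x512 image (SD v1 latents are 1/8 of the pixel resolution); other sizes
    # would presumably need (w // 8, h // 8) here.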
    image = image.resize((64, 64), resample = Image.LANCZOS)
    image = np.array(image)

    if invert:
        print("inverted")
        where_0, where_1 = np.where(image == 0),np.where(image == 255)
        image[where_0], image[where_1] = 255, 0
    image = image.astype(np.float32)/255.0 
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return (image,raw_image)

config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"

parser = argparse.ArgumentParser()

parser.add_argument(
    "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
)
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples")
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")

parser.add_argument(
    "--skip_grid",
    action="store_true",
    help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
    "--skip_save",
    action="store_true",
    help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    default=50,
    help="number of ddim sampling steps",
)

parser.add_argument(
    "--ddim_eta",
    type=float,
    default=0.0,
    help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
    "--n_iter",
    type=int,
    default=1,
    help="sample this often",
)
parser.add_argument(
    "--H",
    type=int,
    default=None,
    help="image height, in pixel space",
)
parser.add_argument(
    "--W",
    type=int,
    default=None,
    help="image width, in pixel space",
)
parser.add_argument(
    "--strength",
    type=float,
    default=0.75,
    help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
)
parser.add_argument(
    "--n_samples",
    type=int,
    default=5,
    help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
    "--n_rows",
    type=int,
    default=0,
    help="rows in the grid (default: n_samples)",
)
parser.add_argument(
    "--scale",
    type=float,
    default=7.5,
    help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
    "--from-file",
    type=str,
    help="if specified, load prompts from this file",
)
parser.add_argument(
    "--mask",
    type=str,
    help="if specified, load mask from this file",
)
parser.add_argument(
    "--dont_fix_mask_degradation",
    action="store_false",
    help="Fixes image degradation outside of masked regions by readding the original image on top of the results.",
    dest='fix_mask_degradation'
)
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="the seed (for reproducible sampling)",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="CPU or GPU (cuda/cuda:0/cuda:1/...)",
)
parser.add_argument(
    "--unet_bs",
    type=int,
    default=1,
    help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
)
parser.add_argument(
    "--turbo",
    action="store_true",
    help="Reduces inference time on the expense of 1GB VRAM",
)
parser.add_argument(
    "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
parser.add_argument(
    "--format",
    type=str,
    help="output image format",
    choices=["jpg", "png"],
    default="png",
)
parser.add_argument(
    "--sampler",
    type=str,
    help="sampler",
    choices=["ddim"],
    default="ddim",
)
opt = parser.parse_args()

tic = time.time()
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
grid_count = len(os.listdir(outpath)) - 1

if opt.seed is None:
    opt.seed = randint(0, 1000000)
seed_everything(opt.seed)

# Logging
logger(vars(opt), log_csv = "logs/img2img_logs.csv")

sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
for key, value in sd.items():
    sp = key.split(".")
    if (sp[0]) == "model":
        if "input_blocks" in sp:
            li.append(key)
        elif "middle_block" in sp:
            li.append(key)
        elif "time_embed" in sp:
            li.append(key)
        else:
            lo.append(key)
for key in li:
    sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
    sd["model2." + key[6:]] = sd.pop(key)

config = OmegaConf.load(f"{config}")

assert os.path.isfile(opt.init_img)
(init_image, raw_image) = load_img(opt.init_img, opt.H, opt.W)
init_image = init_image.to(opt.device)


mask: torch.Tensor = None
raw_mask = None
if opt.mask and os.path.isfile(opt.mask):
    (mask, raw_mask) = load_mask(opt.mask, opt.H, opt.W)
    mask = mask.to(opt.device)

model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = opt.device
model.unet_bs = opt.unet_bs
model.turbo = opt.turbo

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = opt.device

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
del sd
if opt.device != "cpu" and opt.precision == "autocast":
    model.half()
    modelCS.half()
    modelFS.half()
    init_image = init_image.half()
    if mask is not None:
        mask = mask.half()

batch_size = opt.n_samples

# Expand the single-channel mask to the 4 latent channels and repeat it across
# the batch so it matches the shape of the latent tensors.
if mask is not None:
    mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
    mask = repeat(mask, '1 ... -> b ...', b=batch_size)

n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
    prompt = opt.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

else:
    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r") as f:
        data = f.read().splitlines()
        data = batch_size * list(data)
        data = list(chunk(sorted(data), batch_size))

modelFS.to(opt.device)

init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space

if opt.device != "cpu":
    mem = torch.cuda.memory_allocated() / 1e6
    modelFS.to("cpu")
    while torch.cuda.memory_allocated() / 1e6 >= mem:
        time.sleep(1)


assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")


if opt.precision == "autocast" and opt.device != "cpu":
    precision_scope = autocast
else:
    precision_scope = nullcontext

seeds = ""
with torch.no_grad():

    all_samples = list()
    for n in trange(opt.n_iter, desc="Sampling"):
        for prompts in tqdm(data, desc="data"):

            sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
            os.makedirs(sample_path, exist_ok=True)
            base_count = len(os.listdir(sample_path))

            with precision_scope("cuda"):
                modelCS.to(opt.device)
                uc = None
                if opt.scale != 1.0:
                    uc = modelCS.get_learned_conditioning(batch_size * [""])
                if isinstance(prompts, tuple):
                    prompts = list(prompts)

                subprompts, weights = split_weighted_subprompts(prompts[0])
                if len(subprompts) > 1:
                    c = torch.zeros_like(uc)
                    totalWeight = sum(weights)
                    # normalize each "sub prompt" and add it
                    for i in range(len(subprompts)):
                        weight = weights[i]
                        # if not skip_normalize:
                        weight = weight / totalWeight
                        c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                else:
                    c = modelCS.get_learned_conditioning(prompts)

                if opt.device != "cpu":
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelCS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                # encode (scaled latent)
                z_enc = model.stochastic_encode(
                    init_latent,
                    torch.tensor([t_enc] * batch_size).to(opt.device),
                    opt.seed,
                    opt.ddim_eta,
                    opt.ddim_steps,
                )
                
                # run the sampler in latent space; the mask constrains which
                # latent region gets regenerated
                samples_ddim = model.sample(
                    t_enc,
                    c,
                    z_enc,
                    unconditional_guidance_scale=opt.scale,
                    unconditional_conditioning=uc,
                    mask = mask,
                    x_T = init_latent,
                    sampler = opt.sampler
                )
                
                modelFS.to(opt.device)
                print("saving images")
                for i in range(batch_size):

                    x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                    x_sample = 255.0 * torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)                    
                    x_sample = rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") 
                    
                    # Bugfix: pixels outside the inpainted region still degrade a
                    # little with every run, so lerp the original image back on top
                    # of the result using the full-resolution mask (white = keep).
                    if mask is not None and opt.fix_mask_degradation:
                        x_sample = x_sample * (1.0 - raw_mask) + raw_image * raw_mask
                    x_sample = x_sample.astype(np.uint8) 
                    
                    # Save the result image
                    Image.fromarray(x_sample).save(
                        os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
                    )
                    seeds += str(opt.seed) + ","
                    opt.seed += 1
                    base_count += 1

                if opt.device != "cpu":
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelFS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                del samples_ddim
                print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

toc = time.time()

time_taken = (toc - tic) / 60.0

print(
    (
        "Samples finished in {0:.2f} minutes and exported to "
        + sample_path
        + "\n Seeds used = "
        + seeds[:-1]
    ).format(time_taken)
)
```

kiriri avatar Sep 01 '22 12:09 kiriri

@kiriri The script you provided works well for 512x512 images (init-img and mask), although it appears the masked areas must be inverted relative to the original optimized inpainting gradio script to achieve the same effect.

But if I try any other resolution (divisible by 64) for init-img and mask, I get an error. For example, running with 512x1024, I get: RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 2
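My guess is it's the hard-coded 64x64 resize in load_mask: SD v1 latents are 1/8 of the pixel resolution, so a 512x1024 init image gives a 64x128 latent while the mask stays 64x64. A drop-in variant of the script's load_mask like this (untested sketch, assumes the mask file has the same dimensions as the init image) should line the shapes up:

```python
def load_mask(path, h0, w0, invert=False):
    # Same as the original, but resize the latent mask to 1/8 of the (rounded)
    # image size instead of a fixed 64x64.
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # round down to a multiple of 64
    print(f"New mask size ({w}, {h})")
    raw_image = np.array(image.resize((w, h), resample=Image.LANCZOS)).astype(np.float32) / 255.0

    image = image.resize((w // 8, h // 8), resample=Image.LANCZOS)
    image = np.array(image)

    if invert:
        print("inverted")
        where_0, where_1 = np.where(image == 0), np.where(image == 255)
        image[where_0], image[where_1] = 255, 0

    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(image), raw_image
```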

prettydeep avatar Sep 08 '22 13:09 prettydeep