k-diffusion

sample_dpmpp_2m has a bug?

Open hallatore opened this issue 1 year ago • 19 comments

Hi,

I've been playing around with the sample_dpmpp_2m sampler and found that swapping one variable changes/fixes the blur. I don't know the math behind it, so I might be wrong, but I think there might be a bug in the code? Let me know what you think, and whether you want me to create a PR for it.

Here are my results

hallatore · Mar 10 '23

Here is an example testing low steps.

[image: xyz_grid-0023-3849107070]

hallatore · Mar 10 '23

Here is a version that works with DPM++ 2M. At least I seem to get pretty good results with it.

[image: xyz_grid-0031-3849107065]

And with "Always discard next-to-last sigma" turned OFF

[image: xyz_grid-0030-3849107065]

At 10 steps: https://imgsli.com/MTYxMjc5

import torch
from tqdm.auto import trange


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t

        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history yet) or final step to sigma = 0: first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])

            # Changed from the original r = h_last / h: take the ratio of the
            # larger step size over the smaller one.
            h_min = min(h_last, h)
            h_max = max(h_last, h)
            r = h_max / h_min

            # Changed from the original (-h).expm1(): use the average of the
            # last and current step sizes in the exponential-integral factor.
            h_d = (h_max + h_min) / 2
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h_d).expm1() * denoised_d

        old_denoised = denoised
    return x

hallatore · Mar 10 '23
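For anyone who wants to try the posted function on its own, here is a minimal usage sketch (not from the thread). It assumes model is a k-diffusion denoiser wrapper such as k_diffusion.external.CompVisDenoiser, and the latent shape and sigma range are just illustrative SD v1 defaults:

import torch
from k_diffusion.sampling import get_sigmas_karras

# Assumed: `model` maps (x, sigma) to a denoised prediction, k-diffusion style.
sigmas = get_sigmas_karras(n=10, sigma_min=0.0292, sigma_max=14.6146)
x = torch.randn([1, 4, 64, 64]) * sigmas[0]  # start from pure noise at sigma_max
samples = sample_dpmpp_2m(model, x, sigmas)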

Some tests on human faces

[image: xyz_grid-0003-954420047]

hallatore · Mar 11 '23

Will this get addressed, or have you guys just moved on?

2blackbar · Apr 01 '23

Any comment on this @crowsonkb? Are you still maintaining this repo or should we just fork it?

wywywywy · Apr 02 '23

@hallatore yes, please put up a PR for it; thank you!

pbaylies · Apr 09 '23

I need to look at this but I'm sick right now so it might be a few days

crowsonkb · Apr 09 '23

@crowsonkb Take care! 🙏

ClashSAN · Apr 09 '23

to aid understanding, here's what the diff looks like:

[diff image]

it changes the second-order step.

when we compute r: we no longer take the ratio "h_last over h".
instead: we compute it as the ratio of "the greater of (h_last, h)" over "the smaller of (h_last, h)".

when computing x: we no longer use the (-h).expm1() term as-is, but rather replace h with "the average of (h_last, h)".

@hallatore does that sound like a correct description of your technique?
can you explain how you came to this algorithm? is it a more faithful implementation of the paper, or is it a novel idea?
is there a requirement that r be greater than 1? it seems that's the guarantee you're trying to create?

personally: looking at the before/after samples, I'm not convinced it's "better" or more "correct" — to me it looks "different".

Birch-san · Apr 10 '23
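For readers who can't see the diff image: reconstructing it from the function posted above and, to the best of my knowledge, the stock k-diffusion source at the time (a sketch, not the literal diff), the second-order branch changes like this:

# before (stock k-diffusion)
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d

# after (the function posted above)
h_last = t - t_fn(sigmas[i - 1])
h_min = min(h_last, h)
h_max = max(h_last, h)
r = h_max / h_min
h_d = (h_max + h_min) / 2
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h_d).expm1() * denoised_d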

Just curious: is there any guarantee that h_min won't be zero or close to zero, and should there therefore be a check to make sure we have sane values of r? Judging from the samples here, it does seem that this change can help in practice.

pbaylies · Apr 10 '23

A lot of the code takes for granted that h_last is always lower than h. When that is true, we get a factor between 0 and 1. But when it isn't, we get a value above 1. I'm not sure if min/max-ing is the right approach to fixing this, but we never want an r value above 1.

The other change is that h is from the "current step", while denoised_d is a value between the current and last step based on r. I think it makes sense that if denoised_d is a computed blend of the current and last step, then h should also be computed from the same blend; otherwise you apply the current step's denoising factor to the computed denoised_d. Here I'm also unsure whether the average is a good fit.

So to sum up, the changes try to address two things.

  1. In some edge cases the h_last value can be higher than h, which causes the r factor to go above 1.
  2. When multiplying h with denoised_d, we combine a current-step value with a computed last/current-step value, which I'm unsure is a good way to do this.

hallatore · Apr 13 '23
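As a side note on point 1 above, here is a minimal probe (not from the thread) for whether h_last can exceed h, i.e. whether the original r = h_last / h goes above 1, on a given schedule; the Karras schedule and its sigma range are just illustrative, and other schedules may behave differently:

import torch
from k_diffusion.sampling import get_sigmas_karras

sigmas = get_sigmas_karras(n=10, sigma_min=0.0292, sigma_max=14.6146)
t = sigmas[:-1].log().neg()  # t_fn(sigma) = -log(sigma); skip the final sigma = 0
h = t[1:] - t[:-1]           # step sizes in t
r = h[:-1] / h[1:]           # h_last / h for each second-order step
print((r > 1).any())         # True means the edge case is reachable on this schedule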

If you never want an r value above one, then I'd say set that as a max; clamp the range, make sure it's in the valid range you want. And see if you can find some pathological test cases for it!

pbaylies · Apr 21 '23

> If you never want an r value above one, then I'd say set that as a max; clamp the range, make sure it's in the valid range you want. And see if you can find some pathological test cases for it!

Just tried r = torch.clamp(r, max=1.0), and the result is different. Not sure if it's better or worse.

Female warrior, 10 steps, seed 3013411575 without clamp

[image]

With clamp

[image]

> When multiplying h with denoised_d, we use a current-step value with a computed last/current-step value, which I'm unsure is a good way to do this.

But the value of h is already h = t_next - t. Pardon my ignorance, but I still don't understand why we should average it.

wywywywy · Apr 22 '23
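One way to read the averaging that may help here: since h_min and h_max are just h_last and h in some order, the modified update is algebraically equivalent to the following (same variables as the function above):

# equivalent form of the h_d change
h_d = (h_last + h) / 2  # midpoint of the last and current step sizes
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h_d).expm1() * denoised_d

So the exponential-integral factor is evaluated with a step size halfway between the last and current steps, rather than the current step size h that the update actually spans.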

Yes, both of those look good to me...

pbaylies · Apr 22 '23

It looks to me like all you are seeing is faster convergence due to loss of detail. You get fewer artifacts with low numbers of steps, but the final converged image has significantly less detail, no matter how many steps you use.

The overall effect on generated images is like a soft focus, or smearing Vaseline on the lens, like what Star Trek would do every time a woman was on screen. It might look better in some instances (particularly in closeups of women, like the examples posted thus far), but it definitely isn't an overall improvement; this is very obvious in images of space/star fields and other noisy images.

In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.

[image: xyz_grid-0102-20230424115312]
Prompt: realistic telescope imagery of space. Steps: 10, Sampler: DPM++ 2M, CFG scale: 7, Seed: 325325235, Size: 512x512, Model hash: e1441589a6, Model: SD_v1-5-pruned

[image: xyz_grid-0101-20230424114158]

Metachs · Apr 24 '23

@Metachs interesting; what about for higher values of CFG Scale, such as 10 or 15?

pbaylies · Apr 24 '23

Similar.

10 steps: [image: xyz_grid-0108-20230424135535]

40 steps: [image: xyz_grid-0109-20230424135535]

10 steps: [image: xyz_grid-0111-20230424140315]

40 steps: [image: xyz_grid-0112-20230424140315]

Metachs · Apr 24 '23

> In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.

I made the same change to h_d and prefer the result; it seems like a halfway point between the original and the OP's mod.

ride5k · May 30 '23

> > In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.
>
> I made the same change to h_d and prefer the result; it seems like a halfway point between the original and the OP's mod.

How can I make that change?

elen07zz · Jun 07 '23
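For those asking how to make that change: based on Metachs' description above, the "TestX" variant presumably keeps the min/max ratio but reverts to stepping with (-h).expm1(). A sketch of the second-order branch under that reading (hedged, since no code was posted for it):

h_last = t - t_fn(sigmas[i - 1])
h_min = min(h_last, h)
h_max = max(h_last, h)
r = h_max / h_min
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
# revert to the original step size h instead of the averaged h_d
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d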