ComfyUI-layerdiffuse

Mac crash

Open · jerlinn opened this issue 11 months ago · 27 comments

The Decode node always crashes.

💻 Mac M2

โŒ export PYTORCH_ENABLE_MPS_FALLBACK=1 โŒ --force-fp16 โŒ just python main.py

(screenshot: CleanShot 2024-03-03 at 14 44 43)

jerlinn, Mar 03 '24

Same here, running into this issue :(

ynie, Mar 03 '24

M1 here, same issue :( +1

yiwangsimple, Mar 04 '24

I do not have a MacBook with an M-series chip. Can you help confirm whether the issue also exists for SD Forge's implementation? https://github.com/layerdiffusion/sd-forge-layerdiffusion

huchenlei, Mar 04 '24

ๆˆชๅฑ2024-03-05 10 35 36 My observation today is that it's the error that occurs when executing this node that causes Python to crash outright

yiwangsimple, Mar 05 '24

Has any Mac user tried https://github.com/layerdiffusion/sd-forge-layerdiffusion with any success?

huchenlei, Mar 06 '24

@huchenlei Hi, bro. Have you tried to fix this problem? I also have a Mac with an M1 chip, and I hit the same problem.

Tukeping, Mar 06 '24

@huchenlei Hi, bro. Have you tried to fix this problem? I also have a Mac with an M1 chip, and I hit the same problem.

I would like to first confirm whether this issue is ComfyUI-specific or affects SD Forge as well.

huchenlei, Mar 06 '24

@huchenlei Hi, bro. Have you tried to fix this problem? I also have a Mac with an M1 chip, and I hit the same problem.

I would like to first confirm whether this issue is ComfyUI-specific or affects SD Forge as well.

@huchenlei I tried running sd-forge-layerdiffuse and it reported the same error message.

------ Logs below ------

Running on local URL: http://127.0.0.1:7860
To create a public link, set share=True in launch().
model_type EPS
UNet ADM Dimension 2816
Startup time: 14.5s (prepare environment: 0.4s, import torch: 5.2s, import gradio: 1.6s, setup paths: 2.1s, other imports: 3.0s, load scripts: 1.1s, create ui: 0.4s, gradio launch: 0.6s).
Using split attention in VAE
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
Using split attention in VAE
extra {'cond_stage_model.clip_g.transformer.text_model.embeddings.position_ids', 'cond_stage_model.clip_l.transformer.text_model.embeddings.position_ids', 'cond_stage_model.clip_l.logit_scale', 'cond_stage_model.clip_l.text_projection', 'cond_stage_model.clip_g.logit_scale'}
To load target model SDXLClipModel
Begin to load 1 model
Moving model(s) has taken 0.00 seconds
Model loaded in 7.7s (load weights from disk: 0.6s, forge load real models: 5.7s, calculate empty prompt: 1.3s).
[Layer Diffusion] LayerMethod.FG_ONLY_ATTN
To load target model SDXL
Begin to load 1 model
Moving model(s) has taken 3.52 seconds
100%|██████████| 20/20 [00:11<00:00, 1.69it/s]
To load target model AutoencoderKL
100%|██████████| 20/20 [00:10<00:00, 1.81it/s]
Begin to load 1 model
Moving model(s) has taken 0.10 seconds
To load target model UNet1024
Begin to load 1 model
Moving model(s) has taken 0.27 seconds
100%|██████████| 8/8 [00:03<00:00, 2.16it/s]
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:287: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3 '
./webui.sh: line 292: 13941 Abort trap: 6 "${python_cmd}" -u "${LAUNCH_SCRIPT}" "$@"
/opt/homebrew/Caskroom/miniconda/base/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown warnings.warn('resource_tracker: There appear to be %d '

Tukeping, Mar 06 '24

M2 here, same issue :( +1

gabrie, Mar 09 '24

Maybe the MPS framework cannot sort along the batch_size dimension.

jerlinn, Mar 09 '24

M2 Pro + 32 GB RAM, getting the same issue.

tilseam, Mar 10 '24

Mine too: on a Mac M2, as soon as the workflow reaches LayerDiffusion Decode (RGBA), Python aborts and pops up an error.

BannyLon, Mar 11 '24

You don't have to stay stuck on this node on the Mac; you can replace that node's function with another process.

yiwangsimple, Mar 11 '24

You don't have to stay stuck on this node on the Mac; you can replace that node's function with another process.

How do you replace the node's function with another process?

BannyLon, Mar 11 '24

Same error here. How can it be resolved?

hike2008, Mar 15 '24

Looking forward to professional solutions, thank you very much!

BannyLon, Mar 19 '24

This plugin is so popular; how can such a common bug be ignored, with no experts stepping in to help solve it!

BannyLon, Mar 19 '24

Same issue with a Mac M2. It seems to be a problem with the MPS framework: the code sorts data along a dimension (axis) that MPS doesn't support. MPS can currently only sort along the first 4 dimensions of an N-dimensional array, while the tensor here has 5 dimensions.

/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:287: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3
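For reference, a rough sketch of where that fifth dimension comes from (the shapes are assumed here, based on the estimate_augmented() code posted later in this thread):

import torch

# The transparent VAE decoder stacks 8 augmented passes (4 rotations x 2 flips),
# each of shape [B, 4, H, W], into one tensor of shape [8, B, 4, H, W].
result = torch.stack([torch.randn(1, 4, 512, 512) for _ in range(8)], dim=0)
print(result.shape)  # torch.Size([8, 1, 4, 512, 512]) -> 5 dimensions
# torch.median(result, dim=0) on an MPS tensor then trips the
# "Axis = 4. This class only supports axis = 0, 1, 2, 3" assertion.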

dingshanliang, Apr 08 '24

M3 Pro Max + 128 GB gets the same issue.

nigo81, Apr 09 '24

Same issue with a Mac M2 Ultra; it does seem to be an MPS problem. The workaround I found is to use a custom node called "ImageSelector", applied directly after the VAE Decode node to select the layer you want. However, this means losing the matted FG (the RGBA result with a mask). Using the [Generate BG + FG + Blended together] workflow as an example, apply 3 ImageSelectors after VAE Decode and set selected_index to 1, 2, and 3 on the three selectors. What you get is the blended FG+BG, the separate FG (with a gray background, not transparent), and the generated BG. I suppose there are ways to get the mask back, either by using a different selector that supports masks, or by applying segmentation to create a new mask from the FG's gray background. I haven't tested anything yet, but you are welcome to try it yourself.
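If anyone wants to experiment with that segmentation idea, here is a very rough, untested sketch; the helper name, the 0.5 gray level, and the tolerance are all assumptions to tune, not anything from this repo:

import torch

def rough_alpha_from_gray_bg(fg: torch.Tensor, bg_gray: float = 0.5, tol: float = 0.05) -> torch.Tensor:
    # fg: [H, W, 3] image in 0..1 with a roughly flat gray background.
    # Mark a pixel as foreground if any channel deviates from the gray level.
    diff = (fg - bg_gray).abs().amax(dim=-1)
    return (diff > tol).float()  # 1.0 = foreground, 0.0 = background

This will not recover soft, anti-aliased edges, so treat it only as a starting point.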

franksuni, Apr 18 '24

The workaround below works fine on my M1 Max MacBook.

In file lib_layerdiffusion/models.py, find

median = torch.median(result, dim=0).values

and modify it like this:

if self.load_device == torch.device("mps"):
    '''
    Apple silicon devices crash when torch.median() is called on an MPS tensor
    with more than 4 dimensions, so move the tensor to the CPU, take the median
    there, and then move the result back to the GPU.
    '''
    median = torch.median(result.cpu(), dim=0).values
    median = median.to(device=self.load_device, dtype=self.dtype)
else:
    median = torch.median(result, dim=0).values

Save and restart ComfyUI.
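If you prefer not to branch inline, the same idea can be wrapped in a small helper (just a sketch; median_dim0 is not an existing function in this repo):

import torch

def median_dim0(t: torch.Tensor) -> torch.Tensor:
    # torch.median on MPS currently aborts when sorting beyond the 4th axis,
    # so compute the median on the CPU and move the result back to the device.
    if t.device.type == "mps":
        return torch.median(t.cpu(), dim=0).values.to(t.device, dtype=t.dtype)
    return torch.median(t, dim=0).values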

PS: I tried to make a pull request, but it turns out I have no push access to this repo. I don't know how the repo's author has configured the repo settings.

ERROR: Permission to huchenlei/ComfyUI-layerdiffuse.git denied to devgdovg.
fatal: Could not read from remote repository.

Please make sure you have the correct access rights

Can anyone help me with this? Thanks.

devgdovg, Apr 22 '24

To reproduce the crash in a simple scenario, try the code below on your M-series MacBook:

import torch

a = torch.randn(8, 1, 4, 512, 512)   # 5-D tensor, like the stacked augmented passes
mps_device = torch.device("mps")
b = a.to(mps_device)
tt = torch.median(b, dim=0)           # crash here: the MPS sort kernel only supports axes 0-3

But if you try a tensor with fewer dimensions, e.g. a = torch.randn(8, 1, 4, 512), there is no crash.
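Applying the same CPU fallback inside this minimal repro avoids the abort:

# No crash: take the median on the CPU, then move the result back to MPS.
tt_cpu = torch.median(b.cpu(), dim=0).values.to(mps_device)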

devgdovg, Apr 22 '24

(quoting devgdovg's workaround above)

I'm an M1 user. Following the method above, I successfully resolved this error.


forgetphp, May 05 '24

(quoting forgetphp's reply above)

ๅฐ็ป†่Š‚๏ผšไธ็Ÿฅ้“ไธบไป€ไนˆ็”จๆ–‡ๆœฌ็ผ–่พ‘ๅ™จๅ’Œxcodeไฟฎๆ”น๏ผŒๆ’ไปถไผšๅ‡บ้”™่ฟ่กŒไธไบ†ใ€‚็”จVSCodeไฟฎๆ”นๆˆๅŠŸใ€‚

tilseam, May 08 '24

(quoting tilseam's reply above)

ๅŒๆ˜ฏm1๏ผŒ่ƒฝๅฆไธŠไผ ไฝ ไฟฎๆ”น็š„ๆ–‡ไปถ๏ผŒ่ฎฉๆˆ‘ไปฌ่ฆ†็›–ๅŽŸๆ–‡ไปถ่ฏ•่ฏ•

feihuang520, May 08 '24

ๅŒๆ˜ฏm1๏ผŒ่ƒฝๅฆไธŠไผ ไฝ ไฟฎๆ”น็š„ๆ–‡ไปถ๏ผŒ่ฎฉๆˆ‘ไปฌ่ฆ†็›–ๅŽŸๆ–‡ไปถ่ฏ•่ฏ• models.py.zip

tilseam, May 13 '24

ๅŒๆ˜ฏm1๏ผŒ่ƒฝๅฆไธŠไผ ไฝ ไฟฎๆ”น็š„ๆ–‡ไปถ๏ผŒ่ฎฉๆˆ‘ไปฌ่ฆ†็›–ๅŽŸๆ–‡ไปถ่ฏ•่ฏ• models.py.zip @tilseam ๆ‚จๅฅฝ๏ผไปฅไธ‹ๆ˜ฏๆˆ‘ไฟฎๆ”นๅŽ็š„ๆบๆ–‡ไปถใ€‚


import torch.nn as nn
import torch
import cv2
import numpy as np

from tqdm import tqdm
from typing import Optional, Tuple
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


def check_diffusers_version():
    import diffusers
    from packaging.version import parse

    assert parse(diffusers.__version__) >= parse(
        "0.25.0"
    ), "diffusers>=0.25.0 requirement not satisfied. Please install correct diffusers version."


check_diffusers_version()


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class LatentTransparencyOffsetEncoder(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)),
        )

    def __call__(self, x):
        return self.blocks(x)


# 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3
class UNet1024(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 4,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # input
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
        )
        self.latent_conv_in = zero_module(
            nn.Conv2d(4, block_out_channels[2], kernel_size=1)
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=None,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                downsample_padding=downsample_padding,
                resnet_time_scale_shift="default",
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=None,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attention_head_dim=(
                attention_head_dim
                if attention_head_dim is not None
                else block_out_channels[-1]
            ),
            resnet_groups=norm_num_groups,
            attn_groups=None,
            add_attention=True,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[
                min(i + 1, len(block_out_channels) - 1)
            ]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=None,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                resnet_time_scale_shift="default",
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=3, padding=1
        )

    def forward(self, x, latent):
        sample_latent = self.latent_conv_in(latent)
        sample = self.conv_in(x)
        emb = None

        down_block_res_samples = (sample,)
        for i, downsample_block in enumerate(self.down_blocks):
            if i == 3:
                sample = sample + sample_latent

            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        sample = self.mid_block(sample, emb)

        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[
                : -len(upsample_block.resnets)
            ]
            sample = upsample_block(sample, res_samples, emb)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample


def checkerboard(shape):
    return np.indices(shape).sum(axis=0) % 2


def fill_checkerboard_bg(y: torch.Tensor) -> torch.Tensor:
    alpha = y[..., :1]
    fg = y[..., 1:]
    B, H, W, C = fg.shape
    cb = checkerboard(shape=(H // 64, W // 64))
    cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
    cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
    cb = torch.from_numpy(cb).to(fg)
    vis = fg * alpha + cb * (1 - alpha)
    return vis


class TransparentVAEDecoder:
    def __init__(self, sd, device, dtype):
        self.load_device = device
        self.dtype = dtype

        model = UNet1024(in_channels=3, out_channels=4)
        model.load_state_dict(sd, strict=True)
        model.to(self.load_device, dtype=self.dtype)
        model.eval()
        self.model = model

    @torch.no_grad()
    def estimate_single_pass(self, pixel, latent):
        y = self.model(pixel, latent)
        return y

    @torch.no_grad()
    def estimate_augmented(self, pixel, latent):
        args = [
            [False, 0],
            [False, 1],
            [False, 2],
            [False, 3],
            [True, 0],
            [True, 1],
            [True, 2],
            [True, 3],
        ]

        result = []

        for flip, rok in tqdm(args):
            feed_pixel = pixel.clone()
            feed_latent = latent.clone()

            if flip:
                feed_pixel = torch.flip(feed_pixel, dims=(3,))
                feed_latent = torch.flip(feed_latent, dims=(3,))

            feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3))
            feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3))

            eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1)
            eps = torch.rot90(eps, k=-rok, dims=(2, 3))

            if flip:
                eps = torch.flip(eps, dims=(3,))

            result += [eps]

        result = torch.stack(result, dim=0)
        if self.load_device == torch.device("mps"):
            # Apple silicon: torch.median() aborts on an MPS tensor with more
            # than 4 dimensions, so take the median on the CPU and move it back.
            median = torch.median(result.cpu(), dim=0).values
            median = median.to(device=self.load_device, dtype=self.dtype)
        else:
            median = torch.median(result, dim=0).values
        return median

    @torch.no_grad()
    def decode_pixel(
        self, pixel: torch.TensorType, latent: torch.TensorType
    ) -> torch.TensorType:
        # pixel.shape = [B, C=3, H, W]
        assert pixel.shape[1] == 3
        pixel_device = pixel.device
        pixel_dtype = pixel.dtype

        pixel = pixel.to(device=self.load_device, dtype=self.dtype)
        latent = latent.to(device=self.load_device, dtype=self.dtype)
        # y.shape = [B, C=4, H, W]
        y = self.estimate_augmented(pixel, latent)
        y = y.clip(0, 1)
        assert y.shape[1] == 4
        # Restore image to original device of input image.
        return y.to(pixel_device, dtype=pixel_dtype)

forgetphp, May 13 '24