8k Stable Diffusion with tiled VAE
This PR makes it possible to generate 4k images in 8 GB of VRAM using a tiled VAE codec combined with `enable_xformers_memory_efficient_attention()`. With 24 GB of VRAM, you can generate 8k images.
The tiled codec splits the input into overlapping tiles, processes the tiles sequentially, and blends the output tiles together for the final output.
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()  # memory-efficient UNet attention
pipe.vae.enable_tiling()  # tiled VAE decode (this PR)

prompt = "a beautiful landscape photo"
image = pipe(prompt, width=4096, height=2048, num_inference_steps=10).images[0]
image.save("4k_landscape.jpg")
```
It's not perfect. Each tile is decoded independently, so uniform surfaces tend to show tile-to-tile tone variation. You also want to disable this for smaller images, which I'm doing at the pipeline level.
If the tiling artifacts are giving you grief, there's another way to do this: add xformers support to the VAE, switch the VAE to channels_last memory format, and run the up_blocks on the CPU. But that's a different PR.
Example output:
This looks really nice already, nice job @kig! Just to better understand: are there multiple ways tiling can be implemented? Is there a reference paper / implementation for this?
Could we add a test for this?
Thanks @patrickvonplaten! I don't think there's a reference paper / implementation for this; it's based on experimentation. I might be wrong though, and maybe there's a paper out there discussing this. Idea-wise it's similar to the GOBIG upscaler.
And yes, let me write a test for it. I guess a smoke test, one that verifies that 512x512 output matches non-tiled output, and one where 1024x1024 output is mostly similar to non-tiled? Would that last one run OK on the test infra? :)
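Something like this, maybe (a rough sketch with placeholder names and tolerances):

```python
import torch

# Hypothetical comparison test: at 512x512 the latents fit in a single
# tile, so tiled and non-tiled decoding should produce the same image.
def test_tiled_decode_matches_non_tiled(vae):
    latents = torch.randn(1, 4, 64, 64)
    plain = vae.decode(latents).sample
    vae.enable_tiling()
    tiled = vae.decode(latents).sample
    vae.disable_tiling()
    assert torch.allclose(plain, tiled, atol=1e-4)
```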
The tiling can be done in at least a couple of ways that I tried:
- Split the latents into 64x64 non-overlapping tiles and decode them separately. Produces sharp seams between the tiles.
- Add an overlapping border around the tiles (say, pad each tile by a border with width 64, so your (1,4,64,64) tile becomes (1,4,192,192)) and decode them separately, using only the middle part for the output image. This is mostly seamless since the decoders see the neighboring tile latents, but you can still get seams in flat areas. And the per-tile processing time and memory use go way up.
- Decode 64x64 tiles, but overlap each tile with the tile on the left and the tile above it. Blend the overlap in the output tiles with a lerp. No visible seams in the output, and no increase in per-tile memory use (all the tiles are 64x64), but the processing takes longer compared to non-overlapping tiles. This is what the PR code does: 64x64 latent tiles with a 48-pixel stride (sketched below).
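For illustration, here's a minimal sketch of that third scheme (hypothetical `blend` / `decode_tiled` helpers, not the actual PR code; edge handling for sizes that don't divide evenly is only partly handled):

```python
import torch

def blend(a, b, overlap, dim):
    # Linearly crossfade the leading edge of tile b into the trailing
    # edge of its neighbour a along dimension `dim` (2 = rows, 3 = cols).
    overlap = min(a.shape[dim], b.shape[dim], overlap)
    for i in range(overlap):
        t = i / overlap
        if dim == 2:
            b[:, :, i, :] = a[:, :, a.shape[2] - overlap + i, :] * (1 - t) + b[:, :, i, :] * t
        else:
            b[:, :, :, i] = a[:, :, :, a.shape[3] - overlap + i] * (1 - t) + b[:, :, :, i] * t
    return b

def decode_tiled(vae, latents, tile=64, stride=48, scale=8):
    # Decode overlapping 64x64 latent tiles; the 48 stride leaves a
    # 16-latent (128-pixel) overlap that gets lerped away below.
    overlap = (tile - stride) * scale
    rows = []
    for y in range(0, latents.shape[2], stride):
        rows.append([vae.decode(latents[:, :, y:y + tile, x:x + tile]).sample
                     for x in range(0, latents.shape[3], stride)])
    result_rows = []
    for i, row in enumerate(rows):
        out = []
        for j, t in enumerate(row):
            if i > 0:
                t = blend(rows[i - 1][j], t, overlap, dim=2)  # blend with tile above
            if j > 0:
                t = blend(row[j - 1], t, overlap, dim=3)      # blend with tile on the left
            out.append(t[:, :, :stride * scale, :stride * scale])
        result_rows.append(torch.cat(out, dim=3))
    return torch.cat(result_rows, dim=2)
```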
The nicest way to do this would be to make the VAE attention use xformers and make the convolution layers run in a fixed amount of memory. That way it'd produce actually correct results...
There's a PR for the xformers part (#1507), and I got the convolution layers sort of working with channels_last memory format, but they still use tens of GB of RAM -- that one's in https://github.com/kig/diffusers/blob/sd-vae-hires/src/diffusers/models/vae.py#L298 but it's quite messy.
[Going on a tangent.] Profiling the memory use a bit further, running the non-tiled decoder with limited memory seems tricky. The decoder images have channel counts ranging from 512 to 128. The convolutions do run memory-efficiently with channels_last, but if your input image is 8 GB and your output image is 4 GB, you're going to need 12 GB.
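For reference, the memory-format switch itself is just the standard PyTorch call (shown here on the `pipe` from the example above):

```python
import torch

# channels_last lets the convolutions pick memory-efficient kernels;
# it changes the tensor layout only, not the results.
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
```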
VAE Decoder `forward()`: the input is a 4-channel image and the output is a 3-channel image. First (pardon the infodump) `conv_in` goes 4->512, then `mid_block` 512->512, `up_blocks` [512 -> 2x res 512 -> 4x res 256 -> 8x res 128], `conv_norm_out` 128->128, `conv_act` 128->128, `conv_out` 128->3. Peak memory use happens in the last up_block when it `F.interpolate`s to 8x res 256c and then `conv2d`s that to 8x res 128c.
Only the mid_block has an attention layer, the others are a mix of Conv2d, GroupNorm, Dropout and SiLU. I guess the tiling artifacts would come from the mid_block and the GroupNorms. The mid_block can be run on the full image, it's not very memory-intensive. The rest of the pipeline you'd have to tile. Fixing tiling artifacts coming from GroupNorms... I guess you could create a downsampled version of the image, compute the group norm parameters for the image, and apply those instead of the dynamically computed per-tile group norm.
In a nutshell, tile the image after mid_block, replace up_block GroupNorms with a fixed whole-image group norm, run the rest of the pipeline tiled, put the tiles back together for the decoder output.
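A sketch of what that fixed group norm could look like (hypothetical helper names; assumes you can capture the full-image feature map, possibly downsampled, before tiling):

```python
import torch

def whole_image_group_stats(x, num_groups):
    # x: (N, C, H, W) feature map for the whole (downsampled) image
    n, c, h, w = x.shape
    g = x.view(n, num_groups, c // num_groups, h, w)
    mean = g.mean(dim=(2, 3, 4), keepdim=True)
    var = g.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    return mean, var

def fixed_group_norm(tile, mean, var, weight, bias, num_groups, eps=1e-6):
    # Normalize a tile with the precomputed whole-image statistics
    # instead of the per-tile ones nn.GroupNorm would compute.
    n, c, h, w = tile.shape
    g = (tile.view(n, num_groups, c // num_groups, h, w) - mean) / torch.sqrt(var + eps)
    # weight/bias are the GroupNorm layer's learned affine parameters.
    return g.view(n, c, h, w) * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
```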
Sorry for being so slow here - will try to look into it this week!
@patil-suraj could you pick this up maybe?
Now that the UNet has various options for slicing and memory-efficient attention, it's not uncommon to generate results that are bigger than what fits through the VAE.
I haven't reviewed this myself but it sounds like it could be one way to address that problem.
I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same?
Note that we also have a "slice_vae" functionality: https://huggingface.co/docs/diffusers/v0.13.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.enable_vae_slicing but this only makes sense for higher batch sizes, so I guess this is still very much relevant.
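In code, slicing is just the following (it splits along the batch dimension only):

```python
# Decodes a batch one image at a time; it does not tile within a single
# image, so it only helps when generating several images at once.
pipe.enable_vae_slicing()
images = pipe(prompt, num_images_per_prompt=4).images
```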
Can reproduce the results from above when using the faster UniPC sampler (just 20 steps).
> I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same?
@pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version.
From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel which is using WelfordOps to compute the mean and standard deviation.) Making them use custom params would require changing PyTorch or adding a new custom op.
Hello, I have completed a wild hack that achieves exactly what you may want! https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111 This is an upscaler WITHOUT image post-processing; everything is in latent space. The hack breaks the VAE into task queues. Please refer to my repo for details!
I have completed a tricky optimization on VAEs. After tons of tricks, I found my implementation to be nearly perfect in terms of no seams, except that you must use an fp32 VAE for 8K images, otherwise it reports NaNs. You also need giant CPU RAM (~85 GB for 8K images) to store intermediate tensors. My hack is implemented as an Automatic1111 WebUI extension, with recommended and user-changeable tile sizes so people can fit their own GPU VRAM. See https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
The current `vae.enable_tiling()` doesn't work well according to my test; the tiled VAE implementation by @pkuliyi2015 performs better!
E.g.: StableDiffusionUpscalePipeline (with `vae.enable_tiling()` turned on) vs. StableDiffusionUpscaleTiledVAEPipeline (which I modified myself to use a tiled VAE in the decoder). The input image resolution is 362x512, upscaled 4x to 1448x2048 with the model stabilityai/stable-diffusion-x4-upscaler.
`vae.enable_tiling`:
tiled VAE:
The defect also occurs in other images when using `vae.enable_tiling`, but the tiled VAE works well.
Wow, this looks great, it avoids the burnt-out spots! Do you have code or a PR for the TiledVAE?
Would be cool to see a PR here :-)
Hello, I'm the original author of this engineering implementation. I'm willing to make a PR, but I have little time, as I'm exploring the possibility of tiling the UNet (not folding it!). So I will do this PR when I have time. But you are welcome to make a PR to my repo as a new branch. My new workload is extremely heavy and I still need a lot of time for that.
Yes! I simply incorporated the tiled VAE into the StableDiffusionUpscalePipeline and made some code modifications to adapt to the model structure definition of the decoder portion in the diffusers library. Looking forward to your PR. But about the tiled UNet: it may not be suitable for a text-driven upscale pipeline. Different blocks in a large image may contain different semantic information, and if diffusion is performed on the segmented blocks, the text information required by the UNet may differ between them, which means the text also needs to correspond to each image block; that may be challenging to handle.
Yes, it is very challenging.
However, the text information is injected via the cross-attention mechanism, where the softmax(QKᵀ/√d)·V aggregation is a linear process over V. So you can just calculate a small picture's QK, and then do a bilinear interpolation on the result matrix... This is very novel and tricky, ah?
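Purely to illustrate that idea (an untested sketch with single-head attention and made-up names; the softmax makes this only an approximation):

```python
import torch
import torch.nn.functional as F

def approx_cross_attention(q, k, v, h, w, down=4):
    # q: (B, H*W, C) image queries; k, v: (B, T, C) text keys/values.
    # Compute attention on a bilinearly downsampled query grid, then
    # upsample the attended output back to full resolution.
    b, n, c = q.shape
    q_img = q.transpose(1, 2).reshape(b, c, h, w)
    q_small = F.interpolate(q_img, scale_factor=1 / down, mode="bilinear")
    hs, ws = q_small.shape[2:]
    q_small = q_small.flatten(2).transpose(1, 2)                 # (B, hs*ws, C)
    attn = torch.softmax(q_small @ k.transpose(1, 2) / c**0.5, dim=-1)
    out_small = (attn @ v).transpose(1, 2).reshape(b, c, hs, ws)
    out = F.interpolate(out_small, size=(h, w), mode="bilinear")
    return out.flatten(2).transpose(1, 2)                        # (B, H*W, C)
```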
This breaks the tiling (seamless) option, it seems by not connecting/overlapping outer edges over the opposite tile edges. Could this be fixed, to generate 8K seamless tiles?
Hey @GitHub1712,
Could you open a new issue here?