
StableDiffusion: Decode latents separately to run larger batches

kig opened this issue 2 years ago • 11 comments

You can use larger batch sizes if you run the VAE decode one image at a time. Currently the VAE decode runs on the full batch, which limits the batch size on 8GB GPUs, and with it the maximum throughput.

This PR makes VAE decode run one image at a time. This makes 20-image (512x512) batches possible on 8GB GPUs with fp16.
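The core idea can be sketched as a small helper that splits the batch, decodes each chunk and concatenates the results. In a real pipeline `decode_fn` would be something like `lambda z: pipe.vae.decode(z).sample`; here a toy 8x nearest-neighbour upsampler stands in so the sketch runs on CPU. `decode_in_slices` and `toy_decode` are hypothetical names, not diffusers API.

```python
import torch
import torch.nn.functional as F


def decode_in_slices(decode_fn, latents: torch.Tensor, slice_size: int = 1) -> torch.Tensor:
    """Run `decode_fn` on `slice_size` latents at a time and re-join the batch.

    Peak decoder activation memory then grows with `slice_size` rather than
    with the full batch size; the concatenated result has the same shape.
    """
    return torch.cat([decode_fn(chunk) for chunk in latents.split(slice_size)])


# Hypothetical stand-in "decoder": 8x nearest-neighbour upsampling.
def toy_decode(z: torch.Tensor) -> torch.Tensor:
    return F.interpolate(z, scale_factor=8, mode="nearest")


latents = torch.randn(4, 4, 16, 16)  # small batch of latents for the demo
sliced = decode_in_slices(toy_decode, latents)
assert torch.equal(sliced, toy_decode(latents))
```

Because each image is decoded independently, the output matches the full-batch decode; only peak memory (and possibly speed) changes.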

kig · Nov 05 '22 20:11


Hey @kig,

I don't think we should make this the default, as it necessarily makes the execution slower; however, it might make a lot of sense to add an `enable_vae_slicing` function for this. Before doing so, could you maybe provide a code snippet that can be run to see the memory savings?

patrickvonplaten · Nov 07 '22 20:11

E.g., a code snippet that works with your PR on an 8GB GPU but fails with the current pipeline implementation?

patrickvonplaten · Nov 07 '22 20:11

Hi @patrickvonplaten!

Yeah, I agree with the `enable_vae_slicing` approach. Here's a small snippet to test:

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16)
pipe.enable_attention_slicing()
# Disable safety_checker for testing: the 1-step outputs are noise and trigger it.
pipe.safety_checker = None
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# num_inference_steps=1 keeps each run short; only the batch size matters here.
for samples in [1, 4, 8, 16, 32]:
    print(f"Generating {samples} image{'s' if samples > 1 else ''}")
    images = pipe([prompt] * samples, num_inference_steps=1).images
    if len(images) != samples:
        raise RuntimeError(f"Expected {samples} images, got {len(images)}")
```

I also added some simple `time.time()` profiling around the VAE decode, measured from just before the decode to just after the `.cpu()` call.
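The profiling code itself isn't shown in the thread. A minimal sketch of such a timer, using a hypothetical `cuda_timer` helper: CUDA kernels launch asynchronously, so it synchronizes the device (when one is available) before starting and before stopping the clock. In the measurement described above, the trailing `.cpu()` already waits for the GPU, since a blocking device-to-host copy has to wait for the data; the start of the interval is the part that can pick up still-running earlier work.

```python
import time
from contextlib import contextmanager

import torch


@contextmanager
def cuda_timer(label: str, results: dict):
    """Store the wall-clock seconds spent in the `with` body as results[label].

    The device is synchronized before the clock starts and before it stops,
    so work queued earlier (e.g. the UNet steps) isn't billed to this block.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        results[label] = time.perf_counter() - start


timings = {}
with cuda_timer("vae_decode", timings):
    # Placeholder work; in the pipeline this would wrap the VAE decode call.
    torch.randn(256, 256) @ torch.randn(256, 256)
print(f"VAE decode elapsed {timings['vae_decode']}")
```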

Results with `image = self.vae.decode(latents).sample` (full batch):

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00,  1.70s/it]
VAE decode elapsed 0.10824728012084961
Generating 4 images
100%|█| 1/1 [00:00<00:00,  2.49it/s]
VAE decode elapsed 0.905648946762085
Generating 8 images
100%|█| 1/1 [00:01<00:00,  1.33s/it]
VAE decode elapsed 2.5215229988098145
Generating 16 images
100%|█| 1/1 [00:02<00:00,  2.72s/it]
Traceback (most recent call last):
  File "test_samples.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
...
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 8.00 GiB total capacity; 5.53 GiB already allocated; 0 bytes free; 6.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

The weird thing is that the full-batch VAE decode time seems to scale superlinearly: double the batch size and the runtime roughly triples. [edit] I couldn't reproduce this in a Docker image. There, the decode time scaled linearly until memory use got close to 100%, at which point decode time suddenly jumped 10x-50x, followed by an OOM at batch size 12.

Results with `image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)])` (one image at a time):

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00,  1.69s/it]
VAE decode elapsed 0.11412405967712402
Generating 4 images
100%|█| 1/1 [00:00<00:00,  2.51it/s]
VAE decode elapsed 0.4106144905090332
Generating 8 images
100%|█| 1/1 [00:00<00:00,  1.29it/s]
VAE decode elapsed 0.8257348537445068
Generating 16 images
100%|█| 1/1 [00:01<00:00,  1.52s/it]
VAE decode elapsed 1.6556949615478516
Generating 32 images
100%|█| 1/1 [00:08<00:00,  8.05s/it]
VAE decode elapsed 3.727773666381836
```

These scale more or less linearly. The 32-image batch is starting to hit some limit, though: its decode time fluctuated between 3.5 and 8 seconds across runs.

Huh, that was not what I was expecting. I thought decoding the full batch at once would have a small efficiency benefit from avoiding per-call setup work, but it looks like something else is in play.
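Dividing the logged decode times by batch size makes the comparison concrete: on this machine the full-batch per-image cost grows about threefold between 1 and 8 images, while one-at-a-time decoding stays near 0.1 s per image up to 32. These are single runs from one setup (the superlinear scaling didn't reproduce in Docker), so treat them as indicative only.

```python
# VAE decode times in seconds, copied from the two 8GB-GPU logs above.
full_batch = {1: 0.10824728012084961, 4: 0.905648946762085, 8: 2.5215229988098145}
one_at_a_time = {
    1: 0.11412405967712402,
    4: 0.4106144905090332,
    8: 0.8257348537445068,
    16: 1.6556949615478516,
    32: 3.727773666381836,
}


def per_image(timings: dict) -> dict:
    """Map batch size -> decode seconds per image."""
    return {n: t / n for n, t in timings.items()}


for name, timings in [("full batch", full_batch), ("one at a time", one_at_a_time)]:
    rounded = {n: round(s, 3) for n, s in per_image(timings).items()}
    print(f"{name}: {rounded}")
```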

kig · Nov 08 '22 18:11

Testing on a 24GB card, the VAE decode time scales linearly, but it runs into an issue with 32 samples. This is with the full-batch-at-a-time approach.

```
# python test.py
Generating 1 image
100%|█| 1/1 [00:00<00:00,  3.21it/s]
VAE decode elapsed 0.08451700210571289
Generating 4 images
100%|█| 1/1 [00:00<00:00,  3.14it/s]
VAE decode elapsed 0.32993221282958984
Generating 8 images
100%|█| 1/1 [00:00<00:00,  2.01it/s]
VAE decode elapsed 0.6512401103973389
Generating 16 images
100%|█| 1/1 [00:00<00:00,  1.00it/s]
VAE decode elapsed 1.289898157119751
Generating 32 images
100%|█| 1/1 [00:01<00:00,  1.94s/it]
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    images = pipe([prompt] * samples, num_inference_steps=1).images
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/app/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 403, in __call__
    image = self.vae.decode(latents).sample
  File "/app/diffusers/src/diffusers/models/vae.py", line 581, in decode
    dec = self.decoder(z)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/vae.py", line 217, in forward
    sample = up_block(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1322, in forward
    hidden_states = upsampler(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/diffusers/src/diffusers/models/resnet.py", line 58, in forward
    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3918, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
```

[edit] Here, doing the VAE decode one image at a time seems to be 15% faster at batch size 8.
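The 32-image failure on the 24GB card is consistent with a simple element count rather than a memory limit. Assuming the SD v1 VAE config `block_out_channels=(128, 256, 512, 512)` (reversed in the decoder, with an `Upsample2D` after every up block except the last), the largest nearest-upsample output for a 64x64 latent has 256 channels at 512x512. At batch 32 that is exactly 2**31 elements, one more than INT_MAX; decoding one image at a time keeps every call at 2**26.

```python
INT_MAX = 2**31 - 1

# Assumed SD v1 VAE decoder layout: (channels, output side length) of each
# Upsample2D output for a 64x64 latent, i.e. a 512x512 image.
UPSAMPLE_OUTPUTS = [(512, 128), (512, 256), (256, 512)]


def largest_upsample_elements(batch_size: int) -> int:
    """Elements in the biggest nearest-upsample output of one decode call."""
    return max(batch_size * c * side * side for c, side in UPSAMPLE_OUTPUTS)


for batch in (1, 16, 32):
    n = largest_upsample_elements(batch)
    print(batch, n, "exceeds INT_MAX" if n > INT_MAX else "ok")
```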

kig · Nov 08 '22 19:11

Doing the VAE decode one image at a time seems to be 15% faster at batch size 8 and 2% slower at batch size 1. The difference is small enough that it might be noise: about 10 ms per image, or roughly 80 ms (1%) of an 8-image batch that takes 8 seconds to generate, which doesn't move the needle much.

However, you can go a step further and combine a tiled VAE decode with xformers to render 8k images on a 24GB GPU...

Prompt: "a beautiful landscape photograph, 8k", 150 seconds per iteration.

[Image: 7680x4320 output for the prompt above]

3840x2160, with a quarter of the pixels, goes 15x faster, so there's a performance cliff somewhere in between.
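The tiling code used for the 8k render isn't included in the thread. As a rough illustration only, here is a naive tiled decode: split the latent into non-overlapping spatial tiles, decode each one, and paste it into the output. With `tile=64` and the usual 8x VAE scale factor, each tile is 512x512 pixels; a 7680x4320 image corresponds to a 960x540 latent. `naive_tiled_decode` and the toy decoder are hypothetical, and with a real VAE this naive version will show seams at tile borders, which is why practical implementations overlap and blend neighbouring tiles.

```python
import torch
import torch.nn.functional as F


def naive_tiled_decode(decode_fn, latents: torch.Tensor, tile: int = 64, scale: int = 8) -> torch.Tensor:
    """Decode a large latent tile by tile (no overlap, no blending).

    `decode_fn` maps a (B, C, h, w) latent to a (B, 3, h*scale, w*scale) image.
    Edge tiles may be smaller than `tile`; they are pasted at the right offset.
    """
    b, _, h, w = latents.shape
    out = None
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            piece = decode_fn(latents[:, :, y:y + tile, x:x + tile])
            if out is None:
                out = piece.new_zeros(b, piece.shape[1], h * scale, w * scale)
            out[:, :, y * scale:y * scale + piece.shape[2], x * scale:x * scale + piece.shape[3]] = piece
    return out


# Hypothetical stand-in decoder: keep 3 channels and upsample 8x.
def toy_decode(z: torch.Tensor) -> torch.Tensor:
    return F.interpolate(z[:, :3], scale_factor=8, mode="nearest")


latents = torch.randn(1, 4, 100, 150)
image = naive_tiled_decode(toy_decode, latents)
```

The toy decoder is purely local, so tiling reproduces its full-image output exactly; a real VAE decoder has a receptive field that crosses tile borders, so results will differ near the seams.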

kig · Nov 09 '22 09:11

btw, I have a speed boost for the decoder here:
https://github.com/huggingface/diffusers/pull/1203

It eliminates a sqrt() and a multiply, simplifies the 4D tensor to 3D (num_heads is always 1), and uses batched matmul.
yes, it'd be possible to implement sliced attention too, if that's a limiting factor.
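An independent sketch of the two ideas mentioned here (not the code from that PR): with a single head, attention can work on 3D `(batch, tokens, channels)` tensors, and the `1/sqrt(d)` scale can be passed as `alpha` to `torch.baddbmm` instead of scaling `q` and `k` separately. The "older" reference below assumes the previous code scaled both `q` and `k` by `d ** -0.25`, which is how "eliminates a sqrt() and a multiply" reads. Optional query slicing caps the score matrix at `(batch, slice_size, tokens)`; for a 64x64 latent the VAE's attention runs over 4096 tokens, so the unsliced score matrix is 4096x4096 per image.

```python
import math

import torch


def single_head_attention(q, k, v, slice_size=None):
    """softmax(q @ k^T / sqrt(d)) @ v for 3D (batch, tokens, channels) tensors.

    The scale is applied inside the batched matmul via baddbmm's `alpha`.
    With `slice_size`, queries are processed in chunks of that many tokens.
    """
    batch, n_q, d = q.shape
    scale = 1 / math.sqrt(d)
    step = slice_size or n_q
    out = q.new_empty(batch, n_q, v.shape[-1])
    for start in range(0, n_q, step):
        q_chunk = q[:, start:start + step]
        # With beta=0 the (uninitialized) input tensor is ignored entirely.
        scores = torch.baddbmm(
            q.new_empty(batch, q_chunk.shape[1], k.shape[1]),
            q_chunk,
            k.transpose(1, 2),
            beta=0,
            alpha=scale,
        )
        out[:, start:start + step] = torch.bmm(scores.softmax(dim=-1), v)
    return out


def reference_attention(q, k, v):
    """Assumed older formulation: scale both q and k by d ** -0.25, then matmul."""
    s = 1 / math.sqrt(math.sqrt(q.shape[-1]))
    scores = torch.bmm(q * s, (k * s).transpose(1, 2))
    return torch.bmm(scores.softmax(dim=-1), v)
```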

Birch-san · Nov 10 '22 23:11

Sorry for the late reply; it's been hectic. I added an `enable_vae_slicing()` function to `StableDiffusionPipeline` and moved the slicing implementation to `AutoencoderKL`.

Let me know if you prefer it in the pipeline and I can move it there.
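The opt-in design discussed above (off by default, because slicing can be slower) can be sketched as a toggle on the decoder. `SlicedDecoder` is a hypothetical wrapper for illustration, not the `AutoencoderKL` code; the pipeline-level `enable_vae_slicing()` described here would simply flip such a flag on the VAE.

```python
import torch
import torch.nn.functional as F


class SlicedDecoder:
    """Hypothetical wrapper with an opt-in, one-latent-at-a-time decode.

    `decode_fn` stands in for the VAE's batch decode. Slicing is off by
    default so the faster full-batch path stays the default behaviour.
    """

    def __init__(self, decode_fn):
        self.decode_fn = decode_fn
        self.use_slicing = False

    def enable_slicing(self):
        self.use_slicing = True

    def disable_slicing(self):
        self.use_slicing = False

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        if self.use_slicing and latents.shape[0] > 1:
            # Decode each latent separately to cap peak activation memory.
            return torch.cat([self.decode_fn(z) for z in latents.split(1)])
        return self.decode_fn(latents)


def toy_decode(z: torch.Tensor) -> torch.Tensor:
    return F.interpolate(z, scale_factor=8, mode="nearest")


decoder = SlicedDecoder(toy_decode)
decoder.enable_slicing()
latents = torch.randn(3, 4, 8, 8)
images = decoder.decode(latents)
```

The `latents.shape[0] > 1` check skips the split entirely for single images, where slicing would only add overhead.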

kig · Nov 18 '22 02:11

> yes, it'd be possible to implement sliced attention too, if that's a limiting factor.

Thanks! I tried making the VAE use xformers attention, and that did help with memory use. But the convolution layers in `Upsample2D` and `ResnetBlock2D` (src/diffusers/models/resnet.py) turned out to be another limitation.

kig · Nov 18 '22 04:11

Hey @kig,

Great, the API looks very nice to me now :-)

Could we do two last things:

  • 1.) Add docs
  • 2.) Add tests

Here are some links that might help:

  • 1.) Docs could look like https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx#sliced-attention-for-additional-memory-savings; please also add your function here: https://github.com/huggingface/diffusers/blob/main/docs/source/api/pipelines/stable_diffusion.mdx#stablediffusionpipeline
  • 2.) Tests could look like: https://github.com/huggingface/diffusers/blob/ab1f01e63415b63937736299d3a770554c83987e/tests/pipelines/stable_diffusion/test_stable_diffusion.py#L468 and https://github.com/huggingface/diffusers/blob/ab1f01e63415b63937736299d3a770554c83987e/tests/pipelines/stable_diffusion/test_stable_diffusion.py#L743
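The linked tests are the model to follow for the real pipeline. As a lightweight complement, the key invariant (sliced decoding must match full-batch decoding up to floating-point tolerance) can be checked on a tiny module so it runs on CPU in seconds. `ToyDecoder` and `VaeSlicingTest` are hypothetical names for this sketch; run it with `python -m unittest`.

```python
import unittest

import torch


class ToyDecoder(torch.nn.Module):
    """Tiny stand-in for a VAE decoder: 8x upsample, then a 3x3 conv to RGB."""

    def __init__(self):
        super().__init__()
        self.up = torch.nn.Upsample(scale_factor=8, mode="nearest")
        self.conv = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def forward(self, z):
        return self.conv(self.up(z))


class VaeSlicingTest(unittest.TestCase):
    def test_sliced_matches_full_batch(self):
        torch.manual_seed(0)
        decoder = ToyDecoder().eval()
        latents = torch.randn(4, 4, 8, 8)
        with torch.no_grad():
            full = decoder(latents)
            sliced = torch.cat([decoder(z) for z in latents.split(1)])
        self.assertEqual(full.shape, sliced.shape)
        # Batched and per-sample convolutions may differ by rounding only.
        self.assertTrue(torch.allclose(full, sliced, atol=1e-5))
```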

Let me know if you need more pointers :-)

patrickvonplaten · Nov 20 '22 19:11

@patrickvonplaten Here we go: I added tests and docs. Let me know how they look.

kig · Nov 23 '22 15:11

Hey @kig,

Awesome job :-) Merging this PR!

patrickvonplaten · Nov 29 '22 12:11