
Reduce StableDiffusion memory usage

Open josevalim opened this issue 2 years ago • 14 comments

A list of ideas to explore:

  • [x] Lazy transfers (so we don't load all the data into the GPU at once)
  • [x] FP16 on load
  • [x] FP16 policies on Axon
  • [x] ~Attention slicing~ (no longer applicable https://github.com/huggingface/diffusers/issues/4487)
  • [x] ~Flash attention (JAX version)~ (see notes in https://github.com/elixir-nx/bumblebee/pull/300)
  • [ ] DPM-Solver++ (more schedulers here, here, and in the comments below) (another PyTorch implementation)
  • [ ] TokenMerging
  • [ ] LCM+LoRA
  • [x] ~DeepCache~ (not applicable https://github.com/elixir-nx/bumblebee/issues/147#issuecomment-1963787773)

josevalim avatar Jan 12 '23 21:01 josevalim

More on attention: https://pytorch.org/blog/flash-decoding/

josevalim avatar Oct 13 '23 21:10 josevalim

I'd also suggest FlashAttention-2 and Medusa

bfolkens avatar Nov 13 '23 14:11 bfolkens

Alternative to DPM Solver: https://arxiv.org/abs/2311.05556

josevalim avatar Nov 29 '23 01:11 josevalim

More notes on optimizations here:

  • https://huggingface.co/docs/diffusers/main/en/stable_diffusion
  • https://huggingface.co/docs/diffusers/main/en/optimization/fp16
  • https://huggingface.co/docs/diffusers/main/en/optimization/memory

josevalim avatar Dec 12 '23 20:12 josevalim

I tested SD v1-4 on a GPU using the new lower precision options params_variant: "fp16", type: :bf16. Here are a couple runs:

| Type | Steps | Batch, Images | Time | Memory   | Lazy transfers |
|------|-------|---------------|------|----------|----------------|
| bf16 | 20    | 1, 1          | 0.7s | 4669MiB  | No             |
| bf16 | 20    | 1, 4          | 2.2s | 8769MiB  | No             |
| f32  | 20    | 1, 1          | 1.3s | 8759MiB  | No             |
| f32  | 20    | 1, 4          | 4.3s | 16951MiB | No             |
| bf16 | 20    | 1, 1          | 3.7s | 6957MiB  | Yes            |
| f32  | 20    | 1, 1          | 8.2s | 13379MiB | Yes            |

Note that the reported memory is just the final memory usage after running with preallocate: false, so it's not fully reliable. XLA also reserves memory at compilation time; my guess is that it runs some example operations to pick a preferable algorithm or fine-tune algorithm parameters. That said, it seems clear that bf16 roughly halves both memory and time. Oddly, lazy transfers seem to bump the memory usage (though that doesn't mean this much memory is actually required in practice, it's just XLA increasing its reservation, see below).

Source (first row)
# Stable Diffusion testing

```elixir
Mix.install([
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:axon, github: "elixir-nx/axon", override: true},
  {:kino, "~> 0.11.3"},
  {:bumblebee, github: "elixir-nx/bumblebee"}
])

Application.put_env(:exla, :clients,
  host: [platform: :host],
  cuda: [platform: :cuda, preallocate: false]
  # cuda: [platform: :cuda, memory_fraction: 0.3]
  # cuda: [platform: :cuda]
)

Application.put_env(:exla, :preferred_clients, [:cuda, :host])

Nx.global_default_backend({EXLA.Backend, client: :host})
```

## init

```elixir
with {output, 0} <- System.shell("nvidia-smi --query-gpu=memory.total,memory.used --format=csv") do
  IO.puts(output)
end
```

<!-- livebook:{"branch_parent_index":0} -->

## Stable Diffusion fp16

```elixir
repository_id = "CompVis/stable-diffusion-v1-4"

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})

{:ok, clip} =
  Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"},
    params_variant: "fp16",
    type: :bf16
  )

{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_variant: "fp16",
    type: :bf16
  )

{:ok, vae} =
  Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
    architecture: :decoder,
    params_variant: "fp16",
    type: :bf16
  )

{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})

clip = update_in(clip.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))
unet = update_in(unet.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))
vae = update_in(vae.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))

serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 20,
    num_images_per_prompt: 1,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )

Kino.start_child({Nx.Serving, name: SD, serving: serving})
```

```elixir
prompt = "numbat, forest, high quality, detailed, digital art"

output = Nx.Serving.batched_run(SD, prompt)

for result <- output.results do
  Kino.Image.new(result.image)
end
|> Kino.Layout.grid(columns: 2)
```

jonatanklosko avatar Dec 19 '23 09:12 jonatanklosko

I experimented with different values of memory_fraction as an upper limit. For the first entry in the table above:

  • lazy_transfers: :always - 3GiB (3.4s)
  • manual backend_copy - 4.6GiB (0.7s)
  • preallocate_params: true - 6.2GiB (0.7s)

So lazy transfers do help a bit, but imply a significant slowdown.

What's interesting though is that preallocate_params requires more memory than manual backend_copy. It's even more surprising given that the OOM happens at serving runtime, not during the params preallocation.
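
For reference, here is a rough sketch of the three configurations compared above. Option names such as `lazy_transfers` and `preallocate_params` are given from memory and may differ between EXLA/Bumblebee versions, so treat this as an assumption rather than canonical usage:

```elixir
# 1. Lazy transfers: keep params on the host and let EXLA transfer them lazily
#    for each computation (lower peak memory, slower inference).
serving_lazy =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA, lazy_transfers: :always]
  )

# 2. Manual copy: move params to the GPU up front (as in the Livebook source above).
unet = update_in(unet.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))

# 3. Serving-managed preallocation: ask the serving to move params to the device
#    it runs the computation on.
serving_prealloc =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA],
    preallocate_params: true
  )
```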

jonatanklosko avatar Dec 19 '23 10:12 jonatanklosko

preallocate/jit will transfer the data twice, once as arguments and once as the return value. So we probably need a new callback/abstraction to make this easier :D
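
A tiny hypothetical illustration of that double transfer (the helper names here are made up, this just shows the shape of the problem):

```elixir
# Jitting an identity over the params transfers them to the device as arguments and
# again as the returned tensors, whereas a plain backend_copy transfers them only once.
preallocate_via_jit = fn params ->
  Nx.Defn.jit(&Function.identity/1, compiler: EXLA).(params)
end

preallocate_via_copy = fn params ->
  Nx.backend_copy(params, {EXLA.Backend, client: :cuda})
end
```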

josevalim avatar Dec 19 '23 11:12 josevalim

FTR this was fixed in #317, so preallocate_params: true now effectively does a backend_copy :)

jonatanklosko avatar Dec 19 '23 16:12 jonatanklosko

I have added an entry for LCM+LoRA, @wtedw may have input here (and we may need to update and release Axon first). /cc @seanmor5

josevalim avatar Jan 20 '24 18:01 josevalim

I think we should update Axon to better support LoRA. I have a draft in place right now, but I have to revisit it to make it work as I intend :)

seanmor5 avatar Jan 21 '24 14:01 seanmor5

LCM just adapts these nodes in the unet model: https://github.com/wtedw/lorax/blob/main/lib/lorax/lcm.ex#L121-L139. The weights can be found here: https://huggingface.co/latent-consistency/lcm-lora-sdv1-5

For Bumblebee (if trying to make it compatible with most LoRA files on Hugging Face):

  • Bumblebee would need to manually parse the LoRA file to infer how to adapt the model layers. This includes knowing which layers to inject, the LoRA rank, and the LoRA alpha. Unfortunately LCM's HF page doesn't come with a "lora_config" file, but from a quick glance, some models come with an "adapter_config" file. Not sure how common this is though: https://huggingface.co/IlyaGusev/saiga_13b_lora/blob/main/adapter_config.json.
  • The LCM LoRA was trained with something called Kohya: https://github.com/bmaltais/kohya_ss. It uses this layer naming scheme: https://github.com/wtedw/lorax/blob/main/lib/lorax/lcm.ex#L147. I believe this is the most common trainer in use (the weight update itself is sketched below).

If you guys need any PRs, lmk!
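
As a rough illustration of what the adaptation does to a single layer (this is not the Lorax or Axon API, and the shapes and names are assumptions based on the usual LoRA formulation):

```elixir
defmodule LoraSketch do
  # Merge a LoRA update into a frozen weight matrix: W' = W + (alpha / rank) * (down . up).
  # Assumes `lora_down` has shape [in, rank] and `lora_up` has shape [rank, out], so the
  # product matches a dense kernel of shape [in, out].
  def merge(w, lora_down, lora_up, alpha, rank) do
    scale = alpha / rank
    delta = Nx.multiply(Nx.dot(lora_down, lora_up), scale)
    Nx.add(w, delta)
  end
end
```

Injecting the adapter at runtime instead of merging is the same computation, just kept as a separate low-rank branch added onto the layer output.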

wtedw avatar Jan 21 '24 17:01 wtedw

Just a heads up that Stability AI just announced Stable Diffusion 3, so that makes us wonder how much effort we should pour into SD vs SDXL vs SD3. It still probably makes sense to support LoRA on Stable Diffusion, because that will require improvements in Axon and elsewhere that we could use for other models, but custom schedulers and token merging are up for debate at the moment.

josevalim avatar Feb 23 '24 18:02 josevalim

Checking off attention slicing; it has actually been removed from the diffusers docs (https://github.com/huggingface/diffusers/issues/4487) because of flash attention. Either way, the trick amounts to slicing one dimension and using a while loop, which is similar to flash attention at the defn level (as opposed to a custom CUDA kernel), and that didn't turn out to be beneficial.
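
For context, a minimal sketch of what attention slicing amounts to (plain Nx, chunking over the query dimension rather than a defn while loop, just to illustrate the idea, not Bumblebee's implementation):

```elixir
defmodule SlicedAttention do
  # Compute softmax(Q K^T / sqrt(d)) V one query chunk at a time, so only a
  # [chunk_size, kv_len] score matrix is materialized at once.
  def run(query, key, value, chunk_size) do
    {q_len, d} = Nx.shape(query)
    scale = Nx.sqrt(d)

    0..(q_len - 1)//chunk_size
    |> Enum.map(fn start ->
      len = min(chunk_size, q_len - start)
      q_chunk = Nx.slice_along_axis(query, start, len, axis: 0)

      scores = Nx.divide(Nx.dot(q_chunk, Nx.transpose(key)), scale)
      weights = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores, axes: [1], keep_axes: true)))
      weights = Nx.divide(weights, Nx.sum(weights, axes: [1], keep_axes: true))

      Nx.dot(weights, value)
    end)
    |> Nx.concatenate(axis: 0)
  end
end
```

The defn version would loop over the same chunks with a while loop, and as noted above that approach didn't pay off.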

jonatanklosko avatar Feb 26 '24 09:02 jonatanklosko

The main part of Stable Diffusion is the iterative U-Net pass, which happens for a specified number of timesteps. DeepCache is about reusing some of the intermediate layer outputs across diffusion iterations, that is, outputs that are expected to change slowly over time.

This technique is not going to reduce memory usage, because we still need to periodically do an uncached model pass. Given that we need to keep the cached intermediate results, it can increase the usage if anything. It can give a significant speedup, assuming we do a fair number of steps. For SD Turbo or LCM, where we do one or at most a few steps, the caching is not applicable.
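
Purely to illustrate the idea (hypothetical helpers, not something Bumblebee implements): the denoising loop would run the full U-Net every `cache_interval` steps and otherwise reuse the cached deep-block output:

```elixir
# `full_unet.(latents, t)` is assumed to return {noise, deep_features}, and
# `shallow_unet.(latents, t, deep_features)` to return just the noise; both are
# hypothetical stand-ins for splitting the model into shallow and deep blocks.
defmodule DeepCacheSketch do
  def denoise(latents, timesteps, cache_interval, full_unet, shallow_unet, scheduler_step) do
    {latents, _cache} =
      timesteps
      |> Enum.with_index()
      |> Enum.reduce({latents, nil}, fn {t, i}, {latents, cache} ->
        if cache == nil or rem(i, cache_interval) == 0 do
          # Uncached pass: compute and store the deep features for reuse.
          {noise, deep_features} = full_unet.(latents, t)
          {scheduler_step.(latents, noise, t), deep_features}
        else
          # Cached pass: only the shallow layers run, deep features come from the cache.
          noise = shallow_unet.(latents, t, cache)
          {scheduler_step.(latents, noise, t), cache}
        end
      end)

    latents
  end
end
```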

So this may be something we want to explore in the future, depending on SD3 and other research going forward, but I don't think it's immediately relevant for us now.

jonatanklosko avatar Feb 26 '24 10:02 jonatanklosko