[core] support device type device_maps to work with offloading.
What does this PR do?
This PR allows users to pass device_map="cpu" when initializing a pipeline and then enable model CPU offloading.
This is useful when users want to initialize the models on CPU (think of low-VRAM environments) and then call enable_model_cpu_offload(). Quantized models otherwise initialize directly on a supported accelerator, which can lead to OOMs.
The diff below shows the usage change this PR enables:
import torch
from diffusers import Flux2Pipeline, AutoModel
from transformers import Mistral3ForConditionalGeneration
repo_id = "diffusers/FLUX.2-dev-bnb-4bit" # quantized text-encoder and DiT. VAE still in bf16
device = "cuda:0"
torch_dtype = torch.bfloat16
- text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
- repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- dit = AutoModel.from_pretrained(
- repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- pipe = Flux2Pipeline.from_pretrained(
- repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
- )
- pipe.enable_model_cpu_offload()
+ pipe = Flux2Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="cpu")
+ pipe.enable_model_cpu_offload()
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL + Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
    prompt=prompt,
    generator=torch.Generator(device=device).manual_seed(42),
    num_inference_steps=50,
    guidance_scale=4,
).images[0]
image.save("flux2_output.png")
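To sanity-check where the weights end up after loading, something like the sketch below could help. `component_devices` is a hypothetical helper, not part of diffusers or this PR; it assumes only that the pipeline exposes a `components` dict (as diffusers pipelines do) and that model components have a `parameters()` method.

```python
from collections import Counter


def component_devices(pipe):
    """Report, per component, how many parameters sit on each device.

    Hypothetical helper for illustration; not part of diffusers.
    """
    report = {}
    # `components` maps names such as "transformer" or "vae" to sub-models.
    for name, component in pipe.components.items():
        # Skip schedulers, tokenizers, and other parts that hold no weights.
        if not hasattr(component, "parameters"):
            continue
        devices = Counter(str(p.device) for p in component.parameters())
        if devices:
            report[name] = dict(devices)
    return report
```

Calling `component_devices(pipe)` right after `from_pretrained(..., device_map="cpu")` should, if the change behaves as described above, list only `cpu` for every component.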
cc: @asomoza @apolinario
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.