
[core] support device type device_maps to work with offloading.

Open • sayakpaul opened this issue 1 month ago • 1 comment

What does this PR do?

This PR allows users to pass device_map="cpu" when initializing a pipeline and then enable model CPU offloading.

This is useful when users want to initialize the models on CPU (for example, in low-VRAM environments) and then call enable_model_cpu_offload(). Otherwise, quantized models are initialized directly on a supported accelerator, which can lead to out-of-memory (OOM) errors before offloading ever takes effect.
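To make the OOM argument concrete, here is a toy, pure-Python sketch of the peak-memory arithmetic. It assumes a simplified view of model CPU offloading in which only the component that is currently running sits on the accelerator, while activations and other overheads are ignored. The component sizes and the 24 GB budget are hypothetical numbers, not measurements of FLUX.2, and the code does not reproduce how diffusers actually implements offloading.

```python
from dataclasses import dataclass


@dataclass
class Component:
    name: str
    size_gb: int  # hypothetical footprint, not a measured FLUX.2 size


def peak_without_offload(components):
    # Every component is materialized on the accelerator at load time,
    # so all of them are resident at once.
    return sum(c.size_gb for c in components)


def peak_with_model_offload(components, call_order):
    # Simplified model-level offloading (an assumption): components start
    # on CPU, and before a component runs, the previously active one is
    # moved back to CPU, so only one component is resident at a time.
    sizes = {c.name: c.size_gb for c in components}
    peak = 0
    for name in call_order:
        # Only `name` occupies accelerator memory while it runs.
        peak = max(peak, sizes[name])
    return peak


if __name__ == "__main__":
    components = [
        Component("text_encoder", 14),
        Component("transformer", 18),
        Component("vae", 1),
    ]
    budget_gb = 24  # hypothetical accelerator memory
    order = ["text_encoder", "transformer", "vae"]

    eager = peak_without_offload(components)
    offloaded = peak_with_model_offload(components, order)
    print(f"everything on accelerator: {eager} GB, OOM={eager > budget_gb}")
    # everything on accelerator: 33 GB, OOM=True
    print(f"CPU init + model offload: {offloaded} GB, OOM={offloaded > budget_gb}")
    # CPU init + model offload: 18 GB, OOM=False
```

Under these assumptions, the peak drops from the sum of all components to the size of the largest one, which is why initializing on CPU first matters in low-VRAM setups.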

The diff below shows how this PR changes usage:

import torch
from diffusers import Flux2Pipeline, AutoModel
from transformers import Mistral3ForConditionalGeneration

repo_id = "diffusers/FLUX.2-dev-bnb-4bit" # quantized text-encoder and DiT. VAE still in bf16
device = "cuda:0"  # only used for the RNG generator below
torch_dtype = torch.bfloat16

- text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
-     repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- dit = AutoModel.from_pretrained(
-     repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- pipe = Flux2Pipeline.from_pretrained(
-     repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
- )
- pipe.enable_model_cpu_offload()
+ pipe = Flux2Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="cpu")
+ pipe.enable_model_cpu_offload()

prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL + Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
    prompt=prompt,
    generator=torch.Generator(device=device).manual_seed(42),
    num_inference_steps=50,
    guidance_scale=4,
).images[0]

image.save("flux2_output.png")

cc: @asomoza @apolinario

sayakpaul • Dec 09 '25 05:12

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.