
The Modular Diffusers

yiyixuxu opened this issue 1 year ago · 40 comments

Notes

  • overrode this PR https://github.com/huggingface/diffusers/pull/11652/files; need to make sure the new code works for this case

Getting Started with Modular Diffusers

With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:

Write Only What's New: You don't need to rewrite an entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for everything else.

Assemble Like LEGO®: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows and then assemble them into a single pipeline that can conveniently serve multiple workflows. Here we will walk you through how to use a pipeline like this that we built with Modular Diffusers! In later sections, we will also go over how to assemble and build new pipelines!
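To give a quick preview of what assembling blocks looks like, here is a minimal sketch that builds a text-to-image-only pipeline from the preset SDXL block map (the Modular Setup section below walks through this pattern in full):

from diffusers import ModularPipeline
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS

# start from the preset block map and drop the VAE image-encode step we don't need for text-to-image
blocks_map = AUTO_BLOCKS.copy()
blocks_map.pop("image_encoder")

class MyTextToImageBlocks(SequentialPipelineBlocks):
    block_classes = list(blocks_map.values())
    block_names = list(blocks_map.keys())

my_pipe = ModularPipeline.from_block(MyTextToImageBlocks())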

Quick Start with StableDiffusionXLAutoPipeline

import torch

from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

# Create pipeline
auto_pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
auto_pipe.update_states(**components.components)
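Once the models are loaded into the pipeline, a basic text-to-image call looks like this (a minimal sketch; the prompt is just an example):

image = auto_pipe(prompt="an astronaut riding a horse on mars", output="images").images[0]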

Auto Workflow Selection

The pipeline automatically adapts to your inputs (see the sketch after this list):

  • Basic text-to-image: Just provide a prompt
  • Image-to-image: Add an image input
  • Inpainting: Add both image and mask_image
  • ControlNet: Add a control_image
  • And more!
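For example, the same `auto_pipe` call covers several of these workflows depending on which inputs you pass. A quick sketch (here `init_image`, `mask`, and `pose_image` stand in for PIL images you have already loaded):

# text-to-image: prompt only
images = auto_pipe(prompt="an astronaut", output="images").images

# image-to-image: adding `image` triggers the img2img workflow
images = auto_pipe(prompt="an astronaut", image=init_image, strength=0.8, output="images").images

# inpainting: adding `image` and `mask_image` triggers the inpainting workflow
images = auto_pipe(prompt="an astronaut", image=init_image, mask_image=mask, output="images").images

# controlnet: adding `control_image` triggers the ControlNet workflow (requires a loaded controlnet)
images = auto_pipe(prompt="an astronaut", control_image=pose_image, output="images").images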

Auto Documentation

We care a great deal about documentation here at Diffusers, and Modular Diffusers carries this mission forward. All our pipeline blocks come with complete docstrings that automatically compose as you build your pipelines. This means:

  • Every pipeline you build with Modular Diffusers comes with complete documentation automatically (see the example after this list)
  • Input/output signatures are dynamically generated, and the same goes for components and configurations
  • Parameter descriptions and types are included
  • Block relationships and dependencies are documented as well
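For example, you can pull the composed docstring for just the workflow you care about from the blocks that would run, using `get_execution_blocks` and `.doc` (shown in more detail below):

# documentation for the inpainting workflow only
print(auto_pipe.get_execution_blocks("mask_image", "image").doc)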

Inspect your pipeline

# get pipeline info: components/configurations/pipeline blocks/docstring
print(auto_pipe)
see an example of output
ModularPipeline:
==============================

Pipeline Block:
--------------
StableDiffusionXLAutoPipeline
 (Class: SequentialPipelineBlocks)
  • text_encoder (StableDiffusionXLTextEncoderStep)
  • ip_adapter (StableDiffusionXLAutoIPAdapterStep)
  • image_encoder (StableDiffusionXLAutoVaeEncoderStep)
  • before_denoise (StableDiffusionXLAutoBeforeDenoiseStep)
  • denoise (StableDiffusionXLAutoDenoiseStep)
  • decode (StableDiffusionXLAutoDecodeStep)

Registered Components:
----------------------
text_encoder: CLIPTextModel (dtype=torch.float16, device=cpu)
text_encoder_2: CLIPTextModelWithProjection (dtype=torch.float16, device=cpu)
tokenizer: CLIPTokenizer
tokenizer_2: CLIPTokenizer
image_encoder: CLIPVisionModelWithProjection (dtype=torch.float16, device=cpu)
feature_extractor: CLIPImageProcessor
unet: UNet2DConditionModel (dtype=torch.float16, device=cpu)
vae: AutoencoderKL (dtype=torch.float16, device=cpu)
scheduler: EulerDiscreteScheduler
controlnet: ControlNetModel (dtype=torch.float16, device=cpu)
guider: CFGGuider
controlnet_guider: CFGGuider

Registered Configs:
------------------
force_zeros_for_empty_prompt: True
requires_aesthetics_score: False

------------------
This pipeline contains blocks that are selected at runtime based on inputs.

Trigger Inputs: {'control_image', 'control_mode', 'image_latents', 'padding_mask_crop', 'mask_image', 'ip_adapter_image', 'image', 'mask'}
  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_image')`).
Check `.doc` of returned object for more information.

  Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.
  - for image-to-image generation, you need to provide either `image` or `image_latents`
  - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` 
  - to run the controlnet workflow, you need to provide `control_image`
  - to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`
  - to run the ip_adapter workflow, you need to provide `ip_adapter_image`
  - for text-to-image generation, all you need to provide is `prompt`

  Args:

      prompt (`Union[str, List]`, *optional*):
          The prompt or prompts to guide the image generation.

      prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in
          both text-encoders

      negative_prompt (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if
          `guidance_scale` is less than `1`).

      negative_prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not
          defined, `negative_prompt` is used in both text-encoders

      cross_attention_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor`
          in [diffusers.models.attention_processor]

      guidance_scale (`float`, *optional*, defaults to 5.0):
          Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
          `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance
          scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are
          closely linked to the text `prompt`, usually at the expense of lower image quality.

      clip_skip (`Union[int, NoneType]`, *optional*):

      ip_adapter_image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to be used as ip adapter

      height (`Union[int, NoneType]`, *optional*):
          The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below
          512 pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      width (`Union[int, NoneType]`, *optional*):
          The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512
          pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      generator (`Union[Generator, List, NoneType]`, *optional*):
          One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
          generation deterministic.

      image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of
          the image will be masked out with `mask_image` and repainted according to `prompt`.

      mask_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while
          black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel
          (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected
          shape would be `(B, H, W, 1)`.

      padding_mask_crop (`Union[Tuple, NoneType]`, *optional*):
          The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and
          mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect
          ratio of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image
          and mask_image will then be cropped based on the expanded area before resizing to the original image size for
          inpainting. This is useful when the masked area is small while the image is large and contain information
          irrelevant for inpainting, such as background.

      num_images_per_prompt (`int`, *optional*, defaults to 1):
          The number of images to generate per prompt.

      num_inference_steps (`int`, *optional*, defaults to 50):
          The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower
          inference.

      timesteps (`Union[Tensor, NoneType]`, *optional*):
          Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.
          Must be in descending order.

      sigmas (`Union[Tensor, NoneType]`, *optional*):
          Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.

      denoising_end (`Union[float, NoneType]`, *optional*):
          When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before
          it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount
          of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should
          ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup.

      strength (`float`, *optional*, defaults to 0.3):
          Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting).
          Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
          `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1,
          added noise will be maximum and the denoising process will run for the full number of iterations specified in
          `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
          `denoising_start` being declared as an integer, the value of `strength` will be ignored.

      denoising_start (`Union[float, NoneType]`, *optional*):
          The denoising start value to use for the scheduler. Determines the starting point of the denoising process.

      latents (`Union[Tensor, NoneType]`, *optional*):
          Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can
          be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by
          sampling using the supplied random `generator`.

      original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The original size (height, width) of the image that conditions the generation process. If different from
          target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in
          section 2.2 of https://huggingface.co/papers/2307.01952

      target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The target size (height, width) of the generated image. For most cases, this should be set to the desired output
          dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of
          https://huggingface.co/papers/2307.01952

      negative_original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained
          in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      negative_target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's
          micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
          `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning

      negative_crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
          micro-conditioning

      aesthetic_score (`float`, *optional*, defaults to 6.0):
          Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of
          SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952

      negative_aesthetic_score (`float`, *optional*, defaults to 2.0):
          Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. Can be
          used to simulate an aesthetic score of the generated image by influencing the negative text condition.

      control_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is
          used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass
          images as a list for proper batching.

      control_guidance_start (`Union[float, List]`, *optional*, defaults to 0.0):
          The percentage of total steps at which the ControlNet starts applying.

      control_guidance_end (`Union[float, List]`, *optional*, defaults to 1.0):
          The percentage of total steps at which the ControlNet stops applying.

      control_mode (`List`, *optional*):
          The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for
          canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment

      controlnet_conditioning_scale (`Union[float, List]`, *optional*, defaults to 1.0):
          Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list
          of scales.

      guess_mode (`bool`, *optional*, defaults to False):
          Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0.

      guidance_rescale (`float`, *optional*, defaults to 0.0):
          Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion
          Noise Schedules and Sample Steps are Flawed'.

      eta (`float`, *optional*, defaults to 0.0):
          Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others.

      guider_kwargs (`Union[Dict, NoneType]`, *optional*):
          Optional kwargs dictionary passed to the Guider.

      output_type (`str`, *optional*, defaults to pil):
          The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array.

      return_dict (`bool`, *optional*, defaults to True):
          Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple.

      dtype (`dtype`, *optional*):
          The dtype of the model inputs

      preprocess_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under
          `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]

      ip_adapter_embeds (`List`, *optional*):
          Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      negative_ip_adapter_embeds (`List`, *optional*):
          Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      image_latents (`Tensor`, *optional*):
          The latents representing the reference image for image-to-image/inpainting generation. Can be generated in
          vae_encode step.

      mask (`Tensor`, *optional*):
          The mask for the inpainting generation. Can be generated in vae_encode step.

      masked_image_latents (`Tensor`, *optional*):
          The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in
          vae_encode step.

      image_latents (`Union[Tensor, NoneType]`, *optional*):
          The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in
          vae_encode or prepare_latent step.

      crops_coords (`Union[Tuple, NoneType]`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode or prepare_latent step.

      crops_coords (`Tuple`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode step.

  Returns:

      images (`Union[List, List, List]`):
          The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array

Use `get_execution_blocks()` to see which blocks will run for your inputs/workflow. For example, if you want to run a text-to-image ControlNet workflow, you can do this:

print(auto_pipe.get_execution_blocks("control_image"))

To see the docstring relevant to your inputs/workflow:

print(auto_pipe.get_execution_blocks("control_image").doc)

Advanced Workflows

Once you've created the auto pipeline, you can use it for different features as long as you add the required components and pass the required inputs.

# Add ControlNet
auto_pipe.update_states(controlnet=controlnet)

# Enable IP-Adapter
auto_pipe.update_states(image_encoder=..., feature_extractor=...)
auto_pipe.load_ip_adapter("h94/IP-Adapter")

# Add LoRA
auto_pipe.load_lora_weights(...)

# at inference time, pass all the inputs required for your workflow
images = auto_pipe(
    prompt="..",
    control_image=pose_image,        # this triggers the ControlNet workflow
    ip_adapter_image=style_image,    # this triggers the IP-Adapter workflow
    ...
).images

Here is an example you can run for a more complex workflow combining ControlNet, IP-Adapter, LoRA, and PAG:

import torch

from diffusers import ControlNetModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from diffusers.utils import load_image
from diffusers.guider import PAGGuider

dtype = torch.float16  # `components` and `auto_pipe` are reused from the quick start above

# load controlnet
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype)
components.add("controlnet", controlnet)

# load image_encoder for ip adapter
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# load additional components into the pipeline
auto_pipe.update_states(**components.get(["controlnet", "image_encoder", "feature_extractor"]))

# load ip adapter
auto_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipe.set_ip_adapter_scale(0.6)

# let's also load a lora while we're at it
auto_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")

# let's also throw PAG in there because why not!
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
auto_pipe.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)

# prepare inputs
prompt = "an astronaut"
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png")
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")

# Run pipeline with everything combined
images = auto_pipe(
    prompt=prompt,
    control_image=control_image,
    ip_adapter_image=ip_adapter_image,
    output="images"
).images
images[0]


check out more usage examples here

test1: complete testing script for `StableDiffusionXLAutoPipeline`
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLAutoPipeline, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0131_auto_pipeline"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes(builder)      

auto_pipeline_block = StableDiffusionXLAutoPipeline()
auto_pipeline = ModularPipeline.from_block(auto_pipeline_block)
refiner_pipeline = ModularPipeline.from_block(auto_pipeline_block)



# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
components.add("controlnet", controlnet)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)


# load components/config into nodes
auto_pipeline.update_states(**components.components)


# load other components to swap in later
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there's no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()


# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")


# since we want the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")

# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what auto_pipeline told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()


# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    auto_pipeline.unload_lora_weights()

auto_pipeline.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter + pag
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    ip_adapter_image=ip_adapter_image,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test  4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

if not test_pag:
    auto_pipeline.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()

# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)

# let's check out the auto_pipeline info for the img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's check out the auto_pipeline info for the img2img controlnet use case
print(f" auto_pipeline info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
# let's check out the refiner_pipeline
print(f" refiner_pipeline info")
print(refiner_pipeline)
print(f" ")

print(f" refiner_pipeline: triggered by `image_latents`")
print(refiner_pipeline.get_execution_blocks("image_latents"))
print(" ")

print(f" running test8: img2img with refiner")


generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = refiner_pipeline(
    image_latents=latents,  
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's check out the auto_pipeline info for the inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's check out the auto_pipeline info for the inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    image=init_image,
    height=1024,
    width=1024,
    mask_image=inpaint_mask, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

auto_pipeline.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    padding_mask_crop=33, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test13: apg

print(f" ")
print(f" running test13: apg")

apg_guider = APGGuider()
auto_pipeline.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
  prompt=prompt, 
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), vae=components.get("vae_fix"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet union use case
# let's check out the info for the controlnet union use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    generator=generator, 
    control_mode=[3], 
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt, 
    height=1024, 
    width=1024, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    mask_image=inpaint_mask, 
    control_image=controlnet_union_image,
    control_mode=[3],
    height=1024, 
    width=1024, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Modular Setup

StableDiffusionXLAutoPipeline is a very convenient preset; just like LEGO sets, you can break it down, reassemble, and rearrange the pipeline blocks however you want. A more modular setup would look like this:


# AUTO_BLOCKS is a map of all the blocks we used to assemble `StableDiffusionXLAutoPipeline`
from diffusers import ModularPipeline
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    AUTO_BLOCKS,
    StableDiffusionXLLoraStep,
    StableDiffusionXLIPAdapterStep,
)


# step1: create separate nodes to encode text/image inputs and to decode latents
text_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("text_encoder")()) 
image_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("image_encoder")()) 
decoder_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("decode")()) 

# make a node for "denoising", here we just use the leftover blocks
class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_BLOCKS.values())
    block_names = list(AUTO_BLOCKS.keys())

sdxl_node = ModularPipeline.from_block(SDXLAutoBlocks())
# we can also use the same blocks to make a refiner node, but you need to load a different unet/config into it later
refiner_node = ModularPipeline.from_block(SDXLAutoBlocks())

# lora_node for lora related things
lora_node = ModularPipeline.from_block(StableDiffusionXLLoraStep())
# ip_adapter_node for IP-Adapter related things
ip_adapter_node = ModularPipeline.from_block(StableDiffusionXLIPAdapterStep())

# step2: load models into the nodes (sdxl_node and refiner nodes are made with same block but need different components)
...
sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
...

# step3: generate embeddings so they can be reused
text_state = text_node(prompt=...)
image_state = image_node(image=...)
ip_adapter_state = ip_adapter_node(...)

# step4: reuse the embeddings in different workflows, change call parameters, or take the latents to use in a different workflow before decoding
latents_img2img = sdxl_node(**text_state.intermediates, **image_state.intermediates, output="latents")
latents_text2img_28steps = sdxl_node(**text_state.intermediates, num_inference_steps=28, ..., output="latents")
latents_text2img_ipa = sdxl_node(**text_state.intermediates, **ip_adapter_state.intermediates, ..., output="latents")
latents_refined = refiner_node(**text_state.intermediates, image_latents=latents_xx, output="latents")
...

# step5: decode once you are ready
image = decoder_node(latents=latents_refined, output="images").images
image[0]

With this setup, you can precompute embeddings and reuse them across different denoising backends, with different inference parameters such as guidance_scale or num_inference_steps, or with different schedulers. You can modify your workflow by simply adding/removing/swapping blocks without recomputing the entire pipeline over and over again.
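For instance, a single set of text embeddings can be reused across several runs with different settings before you decode anything (a sketch, assuming the `text_node`, `sdxl_node`, `refiner_node`, and `decoder_node` built above):

# encode the prompt once
text_state = text_node(prompt="an astronaut riding a horse")

# reuse the same embeddings with different settings / denoiser setups
latents_fast = sdxl_node(**text_state.intermediates, num_inference_steps=20, output="latents")
latents_base = sdxl_node(**text_state.intermediates, num_inference_steps=50, output="latents")
latents_refined = refiner_node(**text_state.intermediates, image_latents=latents_base, denoising_start=0.8, output="latents")

# decode only what you want to keep
images = decoder_node(latents=latents_refined, output="images").images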

check out the full example script here

test2: modular setup. This is the full testing script I used for more configurations, including inpainting/refiner/union controlnet/APG
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99


# (2) define blocks and nodes(builder)   

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for the refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there's no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")



# using sdxl_node to generate images

# to get info about sdxl_node and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
# so the information might not be super useful for your specific use case; you will find a "trigger inputs" section that says this:

# Trigger Inputs: {'control_mode', 'control_image', 'image_latents', 'mask'}
#  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_mode')`).
# Check `.doc` of returned object for more information.

print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models in use have been moved to the device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can reuse the same image_latents across different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

test3: modular setup with IPAdapter

import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0121_ipa"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the configured device.

    Args:
        message (str, optional): label to include in the printed header.
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes (builders)

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()

# ip adapter step
ip_adapter_step = StableDiffusionXLIPAdapterStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for the refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)
ip_adapter_node = ModularPipeline.from_block(ip_adapter_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))
ip_adapter_node.update_states(**components.get(["unet", "image_encoder", "feature_extractor"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt, negative_prompt=negative_prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")


# use ip adapter to get image embeddings
print(f" ")
print(f" ip_adapter_node:")
print(ip_adapter_node)
print(f" ")
print(f" generating ip adapter image embeddings with ip_adapter_node")
ip_adapter_node.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
ip_adapter_node.set_ip_adapter_scale(0.6)
ip_adapter_state = ip_adapter_node(ip_adapter_image=ip_adapter_image)
print(f" ")
print(f" ip_adapter_state info")
print(ip_adapter_state)
print(" ")


# using sdxl_node to generate images
print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models in use have been moved to the device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can reuse the same image_latents across different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  **ip_adapter_state.intermediates,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Developer Guide: Building with Modular Diffusers

Core Components Overview

The Modular Diffusers architecture consists of four main components:

ModularPipeline

The main interface for creating and running modular pipelines. Unlike traditional pipelines, you don't write it from scratch - it builds itself from pipeline blocks! Example usage:

from diffusers import ModularPipeline
pipe = ModularPipeline.from_block(auto_pipeline_block)
images = pipe(prompt="a cat", num_inference_steps=15, output="images")

PipelineBlock

The fundamental building block, similar to a mellon/comfy node. Each block:

  • Defines required components, inputs, and outputs
  • Implements __call__(pipeline, state) -> (pipeline, state)
  • Can be reused across different pipelines
  • Can be combined with other blocks
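
For illustration, here is a rough sketch of a minimal custom block, modeled on the block examples further down this thread. The class name ScaleLatentsStep and its latents_scale input are made up for this sketch, and the exact base class and import path may differ as this PR evolves (the import path has varied across revisions, e.g. modular_pipeline vs modular_pipeline_builder):

import torch
from typing import Any, List, Tuple

# import path as used in the block example further down this thread;
# adjust it to match your checkout of this PR
from diffusers.pipelines.modular_pipeline_builder import PipelineBlock, PipelineState


class ScaleLatentsStep(PipelineBlock):
    # hypothetical block: rescales the latents produced by a previous block
    expected_components = []

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        # (name, default) pairs for user-facing inputs
        return [("latents_scale", 1.0)]

    @property
    def intermediates_inputs(self) -> List[str]:
        return ["latents"]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["latents"]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState):
        scale = state.get_input("latents_scale")
        latents = state.get_intermediate("latents")
        state.add_intermediate("latents", latents * scale)
        return pipeline, state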

MultiPipelineBlocks

Combines multiple blocks into a bigger one! These combined blocks behave just like single blocks - with their own inputs, outputs, and components, but they are able to handle more complex workflows!

We have two types of MultiPipelineBlocks available; you can use them to combine individual blocks into ready-to-use sets (like LEGO® presets!)

  1. SequentialPipelineBlocks

    • Chains blocks in sequential order
    class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
        block_classes = [InputStep, SetTimestepsStep, ...]
        block_names = ["input", "set_timesteps", ...]
    
  2. AutoPipelineBlocks

    • Provides conditional block selection. AutoPipelineBlocks makes the complex if/else logic in your code disappear! With this, you can write blocks for specific use cases to keep each code path clean, and use AutoPipelineBlocks to combine them into convenient presets that provide a better user experience :)
    • In this example, the ControlNetDenoiseStep will be dispatched when "control_image" is passed by the user; otherwise, it will run the default DenoiseStep
    class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
        block_classes = [ ControlNetDenoiseStep, DenoiseStep]
        block_names = [ "controlnet", "unet"]
        block_trigger_inputs = ["control_image", None]
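
As a quick usage sketch (reusing the sdxl_node assembled in the test scripts above), you can ask a pipeline built from auto blocks which blocks a given set of inputs would actually run:

# default path with no trigger inputs: the plain text2img blocks
print(sdxl_node.get_execution_blocks())
# passing `control_image` selects the controlnet branch instead
print(sdxl_node.get_execution_blocks("control_image"))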
    

PipelineState and BlockStates

PipelineState and BlockStates manage dataflow between and inside blocks; they make debugging really easy! Feel free to print them out at any time to get an overview of all the shapes/types/values of your pipeline/block states.
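
For example, here is a minimal sketch reusing the text_node and sdxl_node from the scripts above (prompt_embeds is a tensor produced by the text encoder step):

text_state = text_node(prompt="a bear sitting in a chair drinking a milkshake")
print(text_state)  # prints an overview of the state's inputs/intermediates/outputs

# pull a single intermediate out of the state
prompt_embeds = text_state.get_intermediate("prompt_embeds")
print(prompt_embeds.shape, prompt_embeds.dtype)

# or forward all intermediates to the next node
latents = sdxl_node(**text_state.intermediates, output="latents")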

yiyixuxu avatar Oct 14 '24 19:10 yiyixuxu

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Very cool!

yoland68 avatar Oct 16 '24 17:10 yoland68

hi this is very interesting! I'm making a Python pipeline flow visual scripting tool, that can auto-convert functions to visual nodes for fast and modular UI blocks demo. Itself is a pip package: https://pypi.org/project/nozyio/

I wanted to integrate diffusers with my flow nodes UI project but found its not very modular. But this PR may change that! Looking forward to see how this evolves.

github: https://github.com/oozzy77/nozyio happy to connect!

oozzy77 avatar Oct 30 '24 14:10 oozzy77

@oozzy77 thanks! do you want to join a slack channel with me? if you want to experiment building something with this PR I'm eager to hear your feedback and iterate base on that

yiyixuxu avatar Oct 30 '24 20:10 yiyixuxu

Hi super willing to join a slack channel with you! What's the workspace/channel I should join? Or you can invite me @.***

oozzy77 avatar Oct 31 '24 20:10 oozzy77

@oozzy77 I sent an invite!

yiyixuxu avatar Oct 31 '24 23:10 yiyixuxu

Basic version of #10112 Detail-Daemon; the original has a more advanced calculation for the adjustment

import torch
from diffusers import StableDiffusionXLModularPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
  StableDiffusionXLTextEncoderStep,
  StableDiffusionXLDecodeLatentsStep,
  StableDiffusionXLInputStep,
  StableDiffusionXLAutoSetTimestepsStep,
  StableDiffusionXLAutoPrepareLatentsStep,
  StableDiffusionXLAutoPrepareAdditionalConditioningStep,
)
from diffusers.pipelines.modular_pipeline_builder import SequentialPipelineBlocks, PipelineState, PipelineBlock
from diffusers.guider import CFGGuider
from typing import Any, List, Tuple


class StableDiffusionXLDetailDenoiseStep(PipelineBlock):
    expected_components = ["unet", "scheduler", "guider"]

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            ("guidance_scale", 5.0),
            ("guidance_rescale", 0.0),
            ("cross_attention_kwargs", None),
            ("generator", None),
            ("eta", 0.0),
            ("guider_kwargs", None),
            ("detail_adjustment", 0.0),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        return [
            "latents",
            "timesteps",
            "num_inference_steps",
            "pooled_prompt_embeds",
            "negative_pooled_prompt_embeds",
            "add_time_ids",
            "negative_add_time_ids",
            "timestep_cond",
            "prompt_embeds",
            "negative_prompt_embeds",
        ]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["latents"]

    def __init__(self, unet=None, scheduler=None, guider=None):
        if guider is None:
            guider = CFGGuider()
        super().__init__(unet=unet, scheduler=scheduler, guider=guider)

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        detail_adjustment = state.get_input("detail_adjustment")

        guidance_scale = state.get_input("guidance_scale")
        guidance_rescale = state.get_input("guidance_rescale")

        cross_attention_kwargs = state.get_input("cross_attention_kwargs")
        generator = state.get_input("generator")
        eta = state.get_input("eta")
        guider_kwargs = state.get_input("guider_kwargs")

        batch_size = state.get_intermediate("batch_size")
        prompt_embeds = state.get_intermediate("prompt_embeds")
        negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds")
        pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds")
        negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds")
        add_time_ids = state.get_intermediate("add_time_ids")
        negative_add_time_ids = state.get_intermediate("negative_add_time_ids")

        timestep_cond = state.get_intermediate("timestep_cond")
        latents = state.get_intermediate("latents")

        timesteps = state.get_intermediate("timesteps")
        num_inference_steps = state.get_intermediate("num_inference_steps")
        disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False

        # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale
        guider_kwargs = guider_kwargs or {}
        guider_kwargs = {
            **guider_kwargs,
            "disable_guidance": disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }

        pipeline.guider.set_guider(pipeline, guider_kwargs)
        # Prepare conditional inputs using the guider
        prompt_embeds = pipeline.guider.prepare_input(
            prompt_embeds,
            negative_prompt_embeds,
        )
        add_time_ids = pipeline.guider.prepare_input(
            add_time_ids,
            negative_add_time_ids,
        )
        pooled_prompt_embeds = pipeline.guider.prepare_input(
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        added_cond_kwargs = {
            "text_embeds": pooled_prompt_embeds,
            "time_ids": add_time_ids,
        }

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)

        with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = pipeline.guider.prepare_input(latents, latents)
                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
                # predict the noise residual
                noise_pred = pipeline.unet(
                    latent_model_input,
                    t * (1.0 + detail_adjustment),
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                noise_pred = pipeline.guider.apply_guidance(
                    noise_pred,
                    timestep=t,
                    latents=latents,
                )
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        state.add_intermediate("latents", latents)

        return pipeline, state


prompt = "A 4k dslr photo of a raccoon wearing an astronaut helmet, photorealistic."


class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
  block_classes = [
      StableDiffusionXLInputStep,
      StableDiffusionXLAutoSetTimestepsStep,
      StableDiffusionXLAutoPrepareLatentsStep,
      StableDiffusionXLAutoPrepareAdditionalConditioningStep,
      StableDiffusionXLDetailDenoiseStep,
  ]
  block_prefixes = [
      "input",
      "set_timesteps",
      "prepare_latents",
      "prepare_add_cond",
      "denoise",
  ]

text_encoder_workflow = StableDiffusionXLTextEncoderStep()
decoder_workflow = StableDiffusionXLDecodeLatentsStep()
sdxl_workflow = StableDiffusionXLMainSteps()

repo = "stabilityai/stable-diffusion-xl-base-1.0"

text_encoder_workflow.add_states_from_pretrained(repo, torch_dtype=torch.float32)
decoder_workflow.add_states_from_pretrained(repo, torch_dtype=torch.float32)
sdxl_workflow.add_states_from_pretrained(repo, torch_dtype=torch.float32)

print(f" text encoder workflow: {text_encoder_workflow}")
print(f" decoder workflow: {decoder_workflow}")
print(f" main sdxl workflow: {sdxl_workflow}")
print(f" prepare_latents: {sdxl_workflow.blocks['prepare_latents_step']}")

text_node = StableDiffusionXLModularPipeline()
text_node.add_blocks(text_encoder_workflow)

sdxl_node = StableDiffusionXLModularPipeline()
sdxl_node.add_blocks(sdxl_workflow)

decoder_node = StableDiffusionXLModularPipeline()
decoder_node.add_blocks(decoder_workflow)

text_state = text_node.run_blocks(prompt=prompt)

print(f" text state: {text_state}")

generator = torch.Generator().manual_seed(0)
latents_state = sdxl_node.run_blocks(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  height=896,
  width=768,
  detail_adjustment=-0.05,
)
latents = latents_state.get_intermediate("latents")

image_state = decoder_node.run_blocks(latents=latents)
image_state.get_output("images").images[0].save("detail.png")

Original

Detail

Diff compared to StableDiffusionXLDenoiseStep

@@ -807,7 +807,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
         return pipeline, state
 
 
-class StableDiffusionXLDenoiseStep(PipelineBlock):
+class StableDiffusionXLDetailDenoiseStep(PipelineBlock):
     expected_components = ["unet", "scheduler", "guider"]
 
     @property
@@ -819,6 +819,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
             ("generator", None),
             ("eta", 0.0),
             ("guider_kwargs", None),
+            ("detail_adjustment", 0.0),
         ]
 
     @property
@@ -847,6 +848,8 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
 
     @torch.no_grad()
     def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        detail_adjustment = state.get_input("detail_adjustment")
+
         guidance_scale = state.get_input("guidance_scale")
         guidance_rescale = state.get_input("guidance_rescale")
 
@@ -912,7 +915,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 # predict the noise residual
                 noise_pred = pipeline.unet(
                     latent_model_input,
-                    t,
+                    t * (1.0 + detail_adjustment),
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,

hlky avatar Dec 11 '24 16:12 hlky

very nice! but I think for this basic use case, we can just make a custom sigma schedule outside the pipeline and pass it to the pipeline https://github.com/huggingface/diffusers/pull/9672 - this works well with your refactor to separate out the schedule, no?

yiyixuxu avatar Dec 11 '24 17:12 yiyixuxu

See https://github.com/huggingface/diffusers/issues/10112#issuecomment-2519752589 - it needs to pass a different sigma to the model than the one used for sampling

hlky avatar Dec 11 '24 17:12 hlky

i see!

yiyixuxu avatar Dec 11 '24 17:12 yiyixuxu

example usage for offloading + modular diffusers


import torch
import os
from diffusers.utils import load_image
from diffusers import  StableDiffusionXLModularPipeline, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLTextEncoderStep, StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInputStep, StableDiffusionXLAutoSetTimestepsStep, StableDiffusionXLAutoPrepareLatentsStep, StableDiffusionXLAutoPrepareAdditionalConditioningStep, StableDiffusionXLAutoDenoiseStep
from diffusers.pipelines.modular_pipeline_builder import SequentialPipelineBlocks



# define device and dtype
device = "cuda:2"
dtype = torch.float16

# define output folder
out_folder = "modular_test_outputs_all_pag_1214_offload"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def print_memory(node_name):
    mem_bytes = torch.cuda.max_memory_allocated(device=device)
    print(f" {node_name} memory allocated: {mem_bytes / (1024 ** 3):.2f} GB")


# (1)Define inputs
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 


# (2) define workflows 

# use pre-defined text encoder step
text_encoder_workflow = StableDiffusionXLTextEncoderStep()

# define `SequentialPipelineBlocks` that contains all steps except text encoder
class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
    block_classes = [
        StableDiffusionXLInputStep, 
        StableDiffusionXLAutoSetTimestepsStep, 
        StableDiffusionXLAutoPrepareLatentsStep, 
        StableDiffusionXLAutoPrepareAdditionalConditioningStep, 
        StableDiffusionXLAutoDenoiseStep, 
        StableDiffusionXLDecodeLatentsStep,
    ]
    block_prefixes = [
        "input", 
        "set_timesteps", 
        "prepare_latents", 
        "prepare_add_cond", 
        "denoise", 
        "decode_latents",
    ]

sdxl_workflow = StableDiffusionXLMainSteps()


# (3) add states to workflows
repo = "stabilityai/stable-diffusion-xl-base-1.0"
text_encoder_workflow.add_states_from_pretrained(repo, torch_dtype=dtype)
sdxl_workflow.add_states_from_pretrained(repo, torch_dtype=dtype)


# (4) add workflow to builder so it can be run
text_node = StableDiffusionXLModularPipeline()
text_node.add_blocks(text_encoder_workflow)

sdxl_node = StableDiffusionXLModularPipeline()
sdxl_node.add_blocks(sdxl_workflow)



# (5) run the workflows

# (5.1) using text_node to generate text embeddings
reset_memory()
# we do not apply offload for text_node
text_node.to(device)
# # if you want to apply offload for text_node, you can use the following code
# text_node.enable_model_cpu_offload(device=device)

text_state = text_node.run_blocks(prompt=prompt)
print_memory("text_node")
# manually move the node to cpu
text_node.to("cpu")
print(f" text_node: {text_node}")


# (5.2) using sdxl_node to generate images
# apply offload for sdxl_node
sdxl_node.enable_model_cpu_offload(device=device)


# (5.2.1) text2img
generator = torch.Generator(device="cuda").manual_seed(0)
state = sdxl_node.run_blocks(**text_state.intermediates, generator=generator)
state.get_output("images").images[0].save(f"{out_folder}/test1_out_text2img.png")
print_memory("sdxl_node")



# (5.2.2) img2img
generator = torch.Generator(device="cuda").manual_seed(0)
state = sdxl_node.run_blocks(**text_state.intermediates, image=init_image, strength=strength, generator=generator)
state.get_output("images").images[0].save(f"{out_folder}/test2_out_img2img.png")
# here you see an increase in memory because the vae is used before and after the denoising step and we do not offload it after the first use
print_memory("sdxl_node")

# no need to move sdxl_node to cpu, it is already offloaded
print(f" sdxl_node: {sdxl_node}")

output

Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:02<00:00,  3.45it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  3.80it/s]
 text_node memory allocated: 1.55 GB
Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them will fail. Please make sure to use an accelerator to run the pipeline in inference, due to the lack of support for`float16` operations on this device in PyTorch. Please, remove the `torch_dtype=torch.float16` argument, or use another device for inference.
Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them will fail. Please make sure to use an accelerator to run the pipeline in inference, due to the lack of support for`float16` operations on this device in PyTorch. Please, remove the `torch_dtype=torch.float16` argument, or use another device for inference.
 text_node: CustomPipeline Configuration:
==============================

Pipeline Blocks:
----------------
0. StableDiffusionXLTextEncoderStep - (CPU offload seq: text_encoder->text_encoder_2)

   -> prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


Registered Components:
----------------------
text_encoder: CLIPTextModel (dtype=torch.float16, device=cpu)
tokenizer_2: CLIPTokenizer
text_encoder_2: CLIPTextModelWithProjection (dtype=torch.float16, device=cpu)
tokenizer: CLIPTokenizer

Registered Auxiliaries:
----------------------

Registered Configs:
------------------
force_zeros_for_empty_prompt: True

Default Call Parameters:
------------------------
prompt: None
prompt_2: None
negative_prompt: None
negative_prompt_2: None
cross_attention_kwargs: None
prompt_embeds: None
negative_prompt_embeds: None
pooled_prompt_embeds: None
negative_pooled_prompt_embeds: None
num_images_per_prompt: 1
guidance_scale: 5.0
clip_skip: None

Required Call Parameters:
--------------------------

Note: These are the default values. Actual values may be different when running the pipeline.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:25<00:00,  1.94it/s]
 sdxl_node memory allocated: 5.21 GB
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45/45 [00:12<00:00,  3.70it/s]
 sdxl_node memory allocated: 5.31 GB
 sdxl_node: CustomPipeline Configuration:
==============================

Pipeline Blocks:
----------------
0. StableDiffusionXLMainSteps - (CPU offload seq: unet->vae)
    • input_step (StableDiffusionXLInputStep) 
    • set_timesteps_step (StableDiffusionXLAutoSetTimestepsStep) 
    • prepare_latents_step (StableDiffusionXLAutoPrepareLatentsStep) 
    • prepare_add_cond_step (StableDiffusionXLAutoPrepareAdditionalConditioningStep) 
    • denoise_step (StableDiffusionXLAutoDenoiseStep) 
    • decode_latents_step (StableDiffusionXLDecodeLatentsStep) 

   prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_embeds -> timesteps, negative_add_time_ids, images, timestep_cond, add_time_ids, num_inference_steps, latents, batch_size, latent_timestep


Registered Components:
----------------------
controlnet: NoneType
unet: UNet2DConditionModel (dtype=torch.float16, device=cpu)
guider: CFGGuider
controlnet_guider: CFGGuider
scheduler: EulerDiscreteScheduler
vae: AutoencoderKL (dtype=torch.float16, device=cpu)

Registered Auxiliaries:
----------------------
image_processor: VaeImageProcessor
control_image_processor: VaeImageProcessor

Registered Configs:
------------------
requires_aesthetics_score: False

Default Call Parameters:
------------------------
prompt: None
prompt_embeds: None
num_inference_steps: 50
timesteps: None
sigmas: None
denoising_end: None
device: None
strength: 0.3
denoising_start: None
num_images_per_prompt: 1
height: None
width: None
generator: None
latents: None
dtype: None
image: None
original_size: None
target_size: None
negative_original_size: None
negative_target_size: None
crops_coords_top_left: (0, 0)
negative_crops_coords_top_left: (0, 0)
guidance_scale: 5.0
aesthetic_score: 6.0
negative_aesthetic_score: 2.0
guidance_rescale: 0.0
cross_attention_kwargs: None
eta: 0.0
guider_kwargs: None
control_image: None
control_guidance_start: 0.0
control_guidance_end: 1.0
controlnet_conditioning_scale: 1.0
guess_mode: False
output_type: 'pil'
return_dict: True

Required Call Parameters:
--------------------------
prompt_embeds: 
pooled_prompt_embeds: 
negative_pooled_prompt_embeds: 
negative_prompt_embeds: 

Note: These are the default values. Actual values may be different when running the pipeline.

yiyixuxu avatar Dec 14 '24 20:12 yiyixuxu

Great initiative! I wanted to add one point: using the current inputs, intermediate_inputs, and outputs is extremely verbose and too repetitive, especially in the call function. I would propose using (inline) dataclasses instead, which makes them easier to use. Here is one example of what I mean:

class StableDiffusionXLDetailDenoiseStep(PipelineBlock):
    expected_components = ["unet", "scheduler", "guider"]

    @dataclass
    class Inputs:
        guidance_scale: float = 5.0
        guidance_rescale: float = 0.0
        cross_attention_kwargs: dict = None
        generator: torch.Generator = None
        eta: float = 0.0
        guider_kwargs: dict = None
        detail_adjustment: float = 0.0

    @dataclass
    class IntermediateInputs:
        batch_size: int
        latents: torch.Tensor
        timesteps: torch.Tensor
        num_inference_steps: int
        pooled_prompt_embeds: torch.Tensor
        negative_pooled_prompt_embeds: torch.Tensor
        add_time_ids: torch.Tensor
        negative_add_time_ids: torch.Tensor
        timestep_cond: torch.Tensor
        prompt_embeds: torch.Tensor
        negative_prompt_embeds: torch.Tensor

    @property
    def inputs(self) -> Inputs:
        return self.Inputs(...)

    @property
    def intermediates_inputs(self) -> IntermediateInputs:
        return self.IntermediateInputs(...)

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["latents"]

    def __init__(self, unet=None, scheduler=None, guider=None):
        if guider is None:
            guider = CFGGuider()
        super().__init__(unet=unet, scheduler=scheduler, guider=guider)

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:

        disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False

        # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale
        guider_kwargs = self.inputs.guider_kwargs or {}
        guider_kwargs = {
            **guider_kwargs,
            "disable_guidance": disable_guidance,
            "guidance_scale": self.inputs.guidance_scale,
            "guidance_rescale": self.inputs.guidance_rescale,
            "batch_size": self.intermediates_inputs.batch_size,
        }

        pipeline.guider.set_guider(pipeline, guider_kwargs)
        # Prepare conditional inputs using the guider
        prompt_embeds = pipeline.guider.prepare_input(
            self.intermediates_inputs.prompt_embeds,
            self.intermediates_inputs.negative_prompt_embeds,
        )
        add_time_ids = pipeline.guider.prepare_input(
            self.intermediates_inputs.add_time_ids,
            self.intermediates_inputs.negative_add_time_ids,
        )
        pooled_prompt_embeds = pipeline.guider.prepare_input(
            self.intermediates_inputs.pooled_prompt_embeds,
            self.intermediates_inputs.negative_pooled_prompt_embeds,
        )

        added_cond_kwargs = {
            "text_embeds": pooled_prompt_embeds,
            "time_ids": add_time_ids,
        }

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = pipeline.prepare_extra_step_kwargs(self.inputs.generator, self.inputs.eta)
        num_warmup_steps = max(len(self.intermediates_inputs.timesteps) - self.intermediates_inputs.num_inference_steps * pipeline.scheduler.order, 0)

        latents = self.intermediates_inputs.latents

        with pipeline.progress_bar(total=self.intermediates_inputs.num_inference_steps) as progress_bar:
            for i, t in enumerate(self.intermediates_inputs.timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = pipeline.guider.prepare_input(latents, latents)
                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
                # predict the noise residual
                noise_pred = pipeline.unet(
                    latent_model_input,
                    t * (1.0 + self.inputs.detail_adjustment),
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=self.intermediates_inputs.timestep_cond,
                    cross_attention_kwargs=self.inputs.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                noise_pred = pipeline.guider.apply_guidance(
                    noise_pred,
                    timestep=t,
                    latents=latents,
                )
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if i == len(self.intermediates_inputs.timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        state.add_intermediate("latents", latents)

        return pipeline, state

The difference is that now I can just refer to the inputs via self.inputs.<input_name> instead of creating yet another variable to assign them to. This simplifies the code a lot, imo. There may be some things I am missing, but I thought I'd mention it since I'd love a more flexible approach in diffusers.

lordsoffallen avatar Dec 25 '24 08:12 lordsoffallen

@lordsoffallen thanks! it was written that way so that it's easier to know which inputs are needed, but your feedback totally makes sense and we will make it less verbose :)

yiyixuxu avatar Dec 28 '24 08:12 yiyixuxu

testing script to use with the latest commit (will keep this one up to date from now on) cc @hlky @asomoza. We now have an auto workflow that supports any combination of text2img, img2img, inpaint, controlnet, controlnet-union, pag, APG, and lora

testing script for modular diffusers (most updated)
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS 

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16


# define output folder
out_folder = "modular_test_outputs_0110"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the configured device.
    
    Args:
        message (str, optional): a label appended to the printed header
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99


# (2) define blocks and nodes (builder)

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()



image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for the refiner because it takes image_latents (from another pipeline) as input
_ = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"

components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)

# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)



# (4) enable auto cpu offload: automatically offload models when available gpu memory go below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")


# using sdxl_node to generate images

# to get info about sdxl_node and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
# so the information might not be super useful for your specific use case; you will find a "trigger inputs" section that says this:
#   Trigger Inputs:
#   --------------
#   This pipeline contains dynamic blocks that are selected at runtime based on your inputs.
#   • Trigger inputs: {'control_image', 'image_latents', 'mask'}
#   • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs
#   • Use .pipeline_block.get_triggered_blocks() to see which blocks will be used for default inputs (when no trigger inputs are provided)
print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.pipeline_block.get_triggered_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test1_out_text2img.png")
print(f" save modular output to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
sdxl_node.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test2_out_text2img_lora.png")
print(f" save modular output to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image without lora again, with pag
print(f" ")
print(f" running test3:text2image without lora again, with pag")
sdxl_node.unload_lora_weights()
sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test3_out_text2img_pag.png")
print(f" save modular output to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, and some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs={"pag_scale": 3.0}, 
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test4_out_text2img_control.png")
print(f" save modular output to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for img2img use case, we encode the image with image_node first, this way we can use the same image_latents for different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    strength=strength, 
    guider_kwargs={"pag_scale": 3.0}, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test5_out_img2img.png")
print(f" save modular output to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs={"pag_scale": 3.0}, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test6_out_img2img_control.png")
print(f" save modular output to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    output="images"
)
images_output.images[0].save(f"{out_folder}/test7_out_img2img_refiner.png")
print(f" save modular output to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test8_out_inpainting.png")
print(f" save modular output to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs={"pag_scale": 3.0}, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test9_out_inpainting_control.png")
print(f" save modular output to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test10_out_inpainting_inpaint_unet.png")
print(f" save modular output to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
images_output.images[0].save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")
print(f" save modular output to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test12_out_apg.png")
print(f" save modular output to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test13_out_text2img_control_union.png")
print(f" save modular output to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test14_out_img2img_control_union.png")
print(f" save modular output to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test15_out_inpainting_control_union.png")
print(f" save modular output to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

yiyixuxu avatar Jan 12 '25 15:01 yiyixuxu

@hlky

We can remove everything related to num_images_per_prompt as it's handled by StableDiffusionXLInputStep, and I think we could make these functions work with a single input, then call them separately with the positive and negative prompt/image from the module

totally agree, I was thinking about that too! do you want to take a stab at it? we need to refactor these functions in the regular pipelines too
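
For illustration only, a rough sketch of that direction (the helper name and signature below are hypothetical, not the actual diffusers API):

import torch

def encode_single_prompt(text_encoder, tokenizer, prompt: str) -> torch.Tensor:
    # hypothetical helper: encodes exactly one prompt, no num_images_per_prompt handling
    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(text_encoder.device))[0]

# a block would then call it once per prompt, e.g.
# prompt_embeds = encode_single_prompt(text_encoder, tokenizer, prompt)
# negative_prompt_embeds = encode_single_prompt(text_encoder, tokenizer, negative_prompt or "")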

yiyixuxu avatar Feb 04 '25 17:02 yiyixuxu

@yiyixuxu Yes I'll work on that

hlky avatar Feb 04 '25 19:02 hlky

Super cool @yiyixuxu @asomoza @hlky! Not reviewing the PR yet since I'm getting a feel for how a developer would be interacting with the library, but I personally found it very intuitive to get started from the examples.

Here's my first try at making a modular diffusers workflow for naive latent upscaling with SDXL:

Code
import torch
import torch.nn.functional as F
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
images[0].save("output.png")

# Latent upscale
# Note that only naive upscaling is done here. Alternatively, a latent upscaler
# model could be used
batch_size, num_channels, latent_height, latent_width = latents.shape
scale_factor = 1.5
upscaled_height, upscaled_width = int(height * scale_factor), int(width * scale_factor)
upscaled_latent_height, upscaled_latent_width = int(latent_height * scale_factor), int(latent_width * scale_factor)
upscaled_latents = F.interpolate(latents, size=(upscaled_latent_height, upscaled_latent_width), mode="nearest-exact")

# Run inference with upscaled latents
strength = 0.5
upscaled_output = pipe(prompt=prompt, image_latents=upscaled_latents, height=upscaled_height, width=upscaled_width, num_inference_steps=40, strength=strength)

images = upscaled_output.intermediates.get("images").images
images[0].save("output_upscaled.png")

On my first try, I passed latents=upscaled_latents instead of image_latents=upscaled_latents, which does not work as expected (does not trigger the SDXL img2img blocks). Since I have the advantage of knowing the library beforehand, I could make an educated guess about the image_latents parameter or quickly find out by looking at the code.

I wonder if things like this may cause some friction in getting started with modular diffusers workflows. In this case, do you think renaming image_latents to latents is a suitable choice to make? Not quite sure why the two are distinguished at the moment, but I will take a look at the code soon to better understand.

a-r-r-o-w avatar Feb 05 '25 02:02 a-r-r-o-w

Question: Let's say we implemented a Flux/SD3 equivalent of the SDXL modular blocks. Now I want to do the same latent upscale thing in the above comment.

To make it possible to upscale latents with every supported model, I would like to create a general-purpose node/block, with different possible init configurations, that takes an ndim=4 latent and upscales it based on the init configs, either naively or using a latent upscaler model. I expect this block to be invoked before the denoiser steps begin. Let's also assume that I have created the auto-pipe instances for both, similar to what's shown in the examples.

How would I go about inserting my custom blocks into the pipeline execution flow? Or, what would the plan of action on the developers' end look like if they want to inject some code before/after each atomic pipeline step that we currently have (vae encode/decode, latent prep, denoise step, ...)?

a-r-r-o-w avatar Feb 05 '25 02:02 a-r-r-o-w

@a-r-r-o-w we sort of have inconsistent parameter names across different pipelines right now; with modular, the same parameters will need to be combined into one, so I guess we will have to pick a name and stick to it

latents is one example:

  • in text-to-image, latents is the initial random noise (which is useful to reproduce results)
  • in img2img,
    • latents is the prepared latents (image latents + noise): https://github.com/huggingface/diffusers/blob/5b1dcd15848f6748c6cec978ef962db391c4e4cd/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L1293
    • the image parameter also accepts the encoded image, so it is the image latents (without the added noise): https://github.com/huggingface/diffusers/blob/5b1dcd15848f6748c6cec978ef962db391c4e4cd/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L714

there is also image and control_image: it is called image for text-to-image controlnet but control_image in image-to-image when there is already an image variable

in your case for upscaling, I think it should be image_latents (currently image in our pipelines), no? (latents in general should include the noise, at least conceptually, even though in some cases we don't need to add noise at all in the prepare_latents process). it is indeed very confusing, and I understand that the output latents has a different meaning from the input latents in our pipelines, which is not ideal.
maybe we can:

  • rename our current latents to init_noise or something
  • latents is the initial latents used in the denoising loop (it may or may not include noise) - it could be the same as image_latents or init_noise
  • image_latents is the encoded image

open to suggestions/discussions

yiyixuxu avatar Feb 05 '25 18:02 yiyixuxu

@a-r-r-o-w

if it is an upscaler that takes latents as input, I think it is most convenient to use it on its own (like in a UI, it would be its own node/pipeline)

maybe make a map like this so it can be used to create different presets?

AUTO_UPSCALE_BLOCKS = OrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
    ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
    ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
    ("upscale", AutoUpscaleStep),
    ("denoise", StableDiffusionXLAutoDenoiseStep),
    ("decode", StableDiffusionXLAutoDecodeStep)
])

make a preset for end-to-end pipeline

class SDXLAutoUpscaleBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_UPSCALE_BLOCKS.values())
    block_names = list(AUTO_UPSCALE_BLOCKS.keys())

auto_pipe_upscaled = ModularPipeline.from_block(SDXLAutoUpscaleBlocks())

just the upscaler node, used standalone

upscaler_block = AUTO_UPSCALE_BLOCKS["upscale"]()
upscaler_node = ModularPipeline.from_block(upscaler_block)
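
And, purely as a hypothetical sketch (not code from this PR), the upscale block itself could look roughly like this, reusing the PipelineBlock pattern shown elsewhere in this thread:

import torch
import torch.nn.functional as F

class NaiveLatentUpscaleStep(PipelineBlock):
    # hypothetical block: naively upscale intermediate latents before the denoise step
    def __init__(self, scale_factor: float = 1.5):
        super().__init__()
        self.scale_factor = scale_factor

    @property
    def intermediates_inputs(self):
        return ["latents"]

    @property
    def intermediates_outputs(self):
        return ["latents"]

    @torch.no_grad()
    def __call__(self, pipeline, state):
        data = self.get_block_state(state)
        _, _, height, width = data.latents.shape
        data.latents = F.interpolate(
            data.latents,
            size=(int(height * self.scale_factor), int(width * self.scale_factor)),
            mode="nearest-exact",
        )
        self.add_block_state(data, state)
        return pipeline, state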

yiyixuxu avatar Feb 05 '25 18:02 yiyixuxu

Did a pass on the examples and the info shared instead of looking through the code too much (following @a-r-r-o-w's philosophy).

Some comments first.

auto_pipe.update_states(**components.components) -- should this be called auto_pipe.update_components()? update_states() seems a bit counterintuitive?

The pipeline automatically adapts to your inputs:

What if the user combines inputs that are supported individually? How do we infer the workflow in such situations? For example, what if I provide both a control_image and a prompt?

print(auto_pipe)

This is very convenient! However, I wonder if the user could restrict the level of info they want to see. I got a bit lost after the args started appearing. Maybe something to consider in the later iterations.

Misc:

  • Similar to get_execution_blocks(), would it make sense to provide a list_execution_blocks() method?
  • intermediates seems to be a very useful attribute that could benefit from some explicit documentation.

Now, I tried to use the SDXL refiner:

Code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")

# Clear things
del components, pipe
torch.cuda.empty_cache()

# Load refiner
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")
pipe.register_to_config(requires_aesthetics_score=False)

# Refine outputs.
output = pipe(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")

It leads to:

ValueError: Model expects an added time embedding vector of length 2560, but a vector of 2816 was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` (1024, 1024) is correctly used by the model.

Questions:

  • What am I doing wrong?
  • Is there a better way of using the refiner with modular diffusers? SDXL base and refiner share some components but it wasn't clear to me how to make it work with a workflow similar to the Diff-Diff one. Some guidance would be nice.

sayakpaul avatar Feb 08 '25 05:02 sayakpaul

@sayakpaul these are really good feedbacks! thank you!

for refiner, you have to do

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

it is a bit verbose as you can see, and that's the case in general for how we load the ModularPipeline, so I'm wildly open to suggestions for improvement in that aspect. One idea I want to play around with is to introduce a "collection" on the components manager (not a good name since it means something different on the hub, but the idea is to let users operate on a group of model components at once, with some pipeline config attached to it) - will push a POC soon
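
(purely as a strawman for that "collection" idea; the API below does not exist yet and the names are made up)

# hypothetical sketch only: group components under a named collection
components.add_from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    collection="refiner",
    torch_dtype=torch.float16,
)
# and load a whole group (plus any pipeline config attached to it) in one call
refiner_pipeline.update_states(**components.get_collection("refiner"))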

auto_pipe.update_states(**components.components) -- should this be called auto_pipe.update_components()

open to better API, but probably not components because we also update config with it

What if the user combines the inputs that are supported? How do we infer for such situations? For example, what if I provide a control_image and prompt?

open to suggestions on how to do better here; currently each pipeline block has a description attribute, and it is up to the developer to document the workflows that are supported and their respective inputs

This is very convenient! However, I wonder if the user could restrict the level of info they want to see. I got a bit lost after the args started appearing. Maybe something to consider in the later iterations.

These are pretty important! We don't have to wait to improve them in later iterations; let's make it better now if possible. Maybe we don't have to print out the docstring (the args etc); we can direct users to use .doc to get them?
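
(for example, something roughly like the snippet below, where .doc is the proposed accessor rather than a finalized API)

# hypothetical: keep the pipeline repr short and put the full auto-generated
# argument docstring behind a dedicated attribute
print(pipe.doc)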

yiyixuxu avatar Feb 11 '25 07:02 yiyixuxu

Thanks Yiyi!

With your suggestion, I could successfully do my first outputs powered by modular diffusers:

updated code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")

# Clear things
del components, pipe
torch.cuda.empty_cache()

# Load refiner
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16)

# Create pipeline
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]), 
    unet=components.get("unet"), 
    force_zeros_for_empty_prompt=True, 
    requires_aesthetics_score=True
)
refiner_pipeline.to("cuda")
# Refine outputs.
output = refiner_pipeline(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")
(images: refiner output vs. base output)

it is a bit verbose as you can see, and that's the case in general for how we load the ModularPipeline, so I'm wildly open to suggestions for improvement in that aspect. One idea I want to play around with is to introduce a "collection" on the components manager (not a good name since it means something different on the hub, but the idea is to let users operate on a group of model components at once, with some pipeline config attached to it) - will push a POC soon

I am fine with verbosity if it teaches the user about how to correctly modify things. Maybe the error message could better reflect how to properly do the update_states() step if that seems feasible at all? Otherwise, it feels like guesswork (or perhaps I am not well-equipped to understand the flow yet).

open to better API, but probably not components because we also update config with it

Ah, I see. Then probably update_attributes()?

open to suggestions on how to do better here, currently each pipelineblock has a description attribute and it is up to the developer to document about workflows that are supported and their respective inputs

I noticed it after I commented. I think this is sufficient for now (no strong opinions). Should we maybe enforce some kind of input validator (validate_inputs(), e.g.) so that different but similar inputs don't interfere with each other's scopes?
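
(for example, a very rough sketch of the kind of check I mean; validate_inputs is hypothetical here)

def validate_inputs(self, **kwargs):
    # hypothetical guard: reject ambiguous combinations of similar inputs
    if kwargs.get("latents") is not None and kwargs.get("image_latents") is not None:
        raise ValueError(
            "Pass either `latents` (initial noise) or `image_latents` (encoded image), not both."
        )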

These are pretty important! We don't have to wait to improve in later iterations. Let's make it better now if it's possible. maybe we don't have to print out the docstring (the args etc), we can direct user to use .doc to get them?

Perfect, this sounds very good!

sayakpaul avatar Feb 11 '25 09:02 sayakpaul

@sayakpaul I looked at the code you linked here I think you don't need to remove the components manager and reload everything again.

# Loading Models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# load just the refiner UNet (reuse the text_encoders that's already in components)
+ refiner_unet = UNet2DConditionModel.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0", 
+     subfolder="unet", 
+     torch_dtype=torch.float16
+ )
+ components.add("refiner_unet", refiner_unet)
# this makes sure all models stay on cpu until a forward pass is invoked; they may be moved back to cpu when more GPU memory is needed
+ components.enable_auto_cpu_offload()

# I think we don't need to do this:
# 1. pipe's states are managed by `components`; if we want to delete everything, delete components in components manager is enough
# 2. GPU memory is already managed by `components`, i.e. if we need more memory to run refiner pipeline,
#    the other unet from the base repo will be offloaded to cpu.
#    We can also add methods to unload/delete models if more explicit control is needed but overall I think we don't need to 
#    delete a model unless we are certain we do not need them anymore
# 3. in this particular use case, we still need the text_encoders so don't recommend deleting them and reloading again here
- # Clear components and free CUDA memory before loading refiner
- del components, pipe
- torch.cuda.empty_cache()
- 
- # Load complete refiner pipeline
- components = ComponentsManager()
- components.add_from_pretrained(
-     "stabilityai/stable-diffusion-xl-refiner-1.0", 
-     torch_dtype=torch.float16
- )

# Refiner Pipeline Setup
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]),
+   unet=components.get("refiner_unet"),  # Using explicitly loaded UNet
-   unet=components.get("unet"),  # Using UNet from complete pipeline
    force_zeros_for_empty_prompt=True,
    requires_aesthetics_score=True
)
Click to expand the code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline, UNet2DConditionModel
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

refiner_unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", subfolder="unet", torch_dtype=torch.float16)
components.add("refiner_unet", refiner_unet)
components.enable_auto_cpu_offload()

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")


# Create pipeline
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]), 
    unet=components.get("refiner_unet"), 
    force_zeros_for_empty_prompt=True, 
    requires_aesthetics_score=True
)
refiner_pipeline.to("cuda")
# Refine outputs.
output = refiner_pipeline(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")

can you help me:

  1. look into whether there is any benefit in deleting the models when switching workflows? In general, I think it is more efficient to offload them to cpu when you work with multiple workflows, but I want to see if there is any use case we missed
  2. how can we do better in docs for this?

yiyixuxu avatar Feb 11 '25 20:02 yiyixuxu

@sayakpaul for the other comments

h then. Then probably update_attributes()?

could be just update too - I will keep this open since it will be very easy to change names later!

Should we maybe enforce some input validator (validate_inputs()

happy to explore this too, if you can share a POC that'd be great!

yiyixuxu avatar Feb 11 '25 20:02 yiyixuxu

look into if there any benefit in deleting the models when switching workflows? In general, I think it is more efficient to offload them to cpu when you work with multiple workflows but want to see if there is any use case we missed

I think this is a valid assumption except for the situations where we don't have enough CPU RAM (48GBs might be low).

how can we do better in docs for this?

I think we could cover the refiner use case (and similar ones) under the theme of "reusing components between workflows". We could make it clear that, to make the most out of reusing, it's recommended to first load all the components needed for the workflows users want to try out and keep them on CPU. Users will always have the option to load any ad-hoc component they may have forgotten in the beginning. If we can make this clear in the docs with examples, I think that should be enough. WDYT?

could be just update too - I will keep this open since it will be very easy to change names later!

Yeah update() is potentially simpler. SGTM!

happy to explore this too, if you can share a POC that'd be great!

Sure, happy to do that. I will branch off of this PR and try to open a PR. Would that work?

sayakpaul avatar Feb 12 '25 03:02 sayakpaul

I finished testing and doing a PoC with the callbacks so I can update the step progress inside a UI. A question about the implementation for discussion here: since we now have the data object, I would love it if we could pass around the whole object instead of the current method (which I found restrictive), where we need to specify which variables we want to expose to the callbacks; this won't be compatible with the current callbacks, though.

So I did this for the PoC to match current implementation:

                if data.callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in data.callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = getattr(data, k)
                    callback_outputs = data.callback_on_step_end(self, i, t, callback_kwargs)

                    data.latents = callback_outputs.pop("latents", data.latents)
                    data.prompt_embeds = callback_outputs.pop("prompt_embeds", data.prompt_embeds)
                    data.added_cond_kwargs["text_embeds"] = callback_outputs.pop("text_embeds", data.added_cond_kwargs["text_embeds"])
                    data.added_cond_kwargs["time_ids"] = callback_outputs.pop("time_ids", data.added_cond_kwargs["time_ids"])

but it could be something like this which is better to me:

                if data.callback_on_step_end is not None:
                    data.callback_on_step_end(self, i, t, data)

what are your thoughts on this @yiyixuxu?
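
(with the second form, a callback could be as simple as the following sketch, reading and mutating data directly)

def ui_progress_callback(pipe, step_index, timestep, data):
    # report progress to a UI; tensors on the block state can also be tweaked in place here
    print(f"step {step_index} (t={timestep}): latents shape {tuple(data.latents.shape)}")

# passed in as the `callback_on_step_end` input of the denoise step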

asomoza avatar Feb 12 '25 18:02 asomoza

@asomoza the second one for sure! that's the point, callbacks should be super easy now

                if data.callback_on_step_end is not None:
                    data.callback_on_step_end(self, i, t, data)

yiyixuxu avatar Feb 12 '25 18:02 yiyixuxu

@sayakpaul souds good!

Sure, happy to do that. I will branch off of this PR and try to open a PR. Would that work?

yiyixuxu avatar Feb 12 '25 18:02 yiyixuxu

It's looking really nice. Obviously there are a lot of intricacies here that I might not have picked up, so in my initial pass I just tried to focus on parts that felt a little unclear to me.

I tried to break it down by the major components in Modular Diffusers.

Components Manager

My understanding here is that Components Manager is responsible for loading all models, schedulers, etc into the Modular Pipeline and performing memory management for the loaded components.

Where it felt a bit unintuitive was trying to determine which model repos can be used with add_from_pretrained and which ones cannot.

For example, this snippet will load all the components of the base SDXL pipeline into the Components Manager:

# Load models
components = ComponentsManager()
components.add_from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

But if I want to load just a ControlNet model from a model repo, the same call does not work:

components.add_from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)

Instead, I have to create the object myself and add it to the Components Manager via the add method.

Since I'm familiar with the library, I realise that this is following our existing Pipeline loading logic. But I think it might make sense to support adding individual model components through add_from_pretrained as well. We may need to introduce AutoModel logic or something to make it happen.

PipelineBlock

My understanding here is that a PipelineBlock is not expected to load any models, but instead only runs a computation step using the preloaded models in the ComponentsManager, or perhaps some custom code. I also think that, from a user perspective, most people building with Modular will most likely be developing new block types.

The PipelineBlock is also meant to be stateless, with all stateful operations managed through the PipelineState or BlockState?

Let's say I want to add a PipelineBlock that has a model associated with the step. In the example below I want to create a block that automatically extracts a depth map from an image so that I can use it with a ControlNet.

Can I add the depth model to the ComponentsManager from the block, in a manner similar to register_to_config? Or should I always add the model to the ComponentsManager first and then update the block state (see the sketch after the snippet below)? What is the correct way to create a block with an associated model?

class DepthBlock(PipelineBlock):
    @property
    def inputs(self) -> List[InputParam]:
        control_image = InputParam(
            name="control_image",
            required=True,
        )
        return [control_image]

    def __init__(self) -> None:
        super().__init__()
        # If I load a model in a pipeline block, is it possible to move it to the components manager?
        self.depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)
        control_image = data.control_image
        depth_image = self.depth_preprocessor(control_image)
        data.control_image = depth_image

        self.add_block_state(data, state)

        return pipeline, state
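To make the second option concrete, this is roughly what I mean by registering the model with the ComponentsManager up front (the add call and the pipeline.depth_preprocessor access are assumptions for illustration, not existing API):

# same preprocessor class as in the block above (e.g. from image_gen_aux)
depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
components.add("depth_preprocessor", depth_preprocessor)

# and inside DepthBlock.__call__ the block would use it via the pipeline
# instead of loading it in __init__:
#     depth_image = pipeline.depth_preprocessor(data.control_image)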

When initializing PipelineBlocks we have these lines

class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
    expected_components = ["vae"]
    model_name = "stable-diffusion-xl"

And then in the __init__ we have

    def __init__(self):
        super().__init__()
        self.components["vae"] = None
        self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8)

I found it a bit confusing as to why we are setting self.components["vae"] = None during the init of the PipelineBlock, because based on the class attribute it feels like it should be initialized with something? Additionally, self.components doesn't seem to be used anywhere in __call__, so its application or use feels a bit unclear.

Are the class attributes at the top of the block needed? As far as I can tell from skimming the code, we operate on block instances everywhere. Could we define PipelineBlocks in the way shown below instead? IMO it's a bit more Pythonic and makes the blocks feel a bit more like mini-pipelines. You can also add type enforcement checks on the components. LMK if I'm missing something here.

class StableDiffusionXLTextEncoderStep(PipelineBlock):
    def __init__(
        self,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        force_zeros_for_empty_prompt=True,
    ):
        super().__init__()

        # this would set expected_configs
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

        # this would set expected_components
        self.register_component(
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )
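With an __init__ like that, building the block would feel like working with a mini-pipeline, e.g. reusing the Components Manager from earlier (hypothetical, since this constructor doesn't exist yet):

text_block = StableDiffusionXLTextEncoderStep(
    text_encoder=components.components["text_encoder"],
    text_encoder_2=components.components["text_encoder_2"],
    tokenizer=components.components["tokenizer"],
    tokenizer_2=components.components["tokenizer_2"],
)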

Another thing: I wasn't quite able to figure out the exact scope of a PipelineBlock. Should it operate as a single atomic unit, or be aware of the global pipeline methods that are available?

Here, let's say we are encoding a prompt. In the example StableDiffusionXLTextEncoderStep this is done in the following way:

        (
            data.prompt_embeds,
            data.negative_prompt_embeds,
            data.pooled_prompt_embeds,
            data.negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(
            data.prompt,
            data.prompt_2,
            data.device,
            1,
            data.do_classifier_free_guidance,
            data.negative_prompt,
            data.negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=data.text_encoder_lora_scale,
            clip_skip=data.clip_skip,
        )

The encode_prompt method is defined at the global pipeline level. So if I'm trying to understand the Block I need to hop back and forth between the block and StableDiffusionXLModularPipeline. What if I want to create a custom prompt encoding method for the pipeline? Should I define it inside the block? Or do I have to rewrite StableDiffusionXLModularPipeline with a new method?

Can encode_prompt be defined and executed inside the block itself? In that case, when you read StableDiffusionXLTextEncoderStep you get a full understanding of what is happening. If you need to access the encoding method from the ModularPipeline instance, you could do something like

my_modular_pipe.pipeline_block['text_encoder_step'].encode_prompt()

I think Modular actually supports this workflow already.

Is it also considered bad practice to set components as attributes on the blocks and use them that way? Something like this?

	@torch.no_grad()
	def __call__(self, pipeline, state: PipelineState) -> PipelineState:
		# Get inputs and intermediates
		data = self.get_block_state(state)
		self.check_inputs(pipeline, data)
		prompt_embeds = self.text_encoder(data.prompt)

Regarding auxiliaries, is there a strong reason not to have these objects just be considered components as well?

Auto Workflow

I am a little apprehensive about introducing Auto workflows in V1. IMO it's better to let users get used to the mechanics of using Modular manually before introducing any "magic". But I will leave this to your discretion.

Modular Pipeline, Block State, Pipeline State

I like these a lot and I'm pretty much aligned on how they work.

One small nit that is unrelated to the actual functionality (just putting it out here for consideration): I would prefer that we use block_state instead of data for this variable

	@torch.no_grad()
	def __call__(self, pipeline, state: PipelineState) -> PipelineState:
		# Get inputs and intermediates
		data = self.get_block_state(state)

Obviously the work here is very extensive and I'm still playing around with it. LMK if I've misunderstood some concepts or if I should open PRs to try and clarify any of these points.

DN6 avatar Feb 13 '25 14:02 DN6