`SDXLCFGCutoffCallback` does not work with `StableDiffusionXLControlNetPipeline`
Describe the bug
Running `SDXLCFGCutoffCallback` with the ControlNet SDXL pipeline raises the following error:
diffusers/src/diffusers/models/attention.py:372, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    364 norm_hidden_states = self.pos_embed(norm_hidden_states)
    366 attn_output = self.attn2(
    367     norm_hidden_states,
    368     encoder_hidden_states=encoder_hidden_states,
    369     attention_mask=encoder_attention_mask,
    370     **cross_attention_kwargs,
    371 )
--> 372 hidden_states = attn_output + hidden_states
    374 # 4. Feed-forward
    375 # i2vgen doesn't have this norm 🤷♂️
    376 if self.norm_type == "ada_norm_continuous":

RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1
This occurs because the conditioning image (https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1488) is not converted back to batch size 1 when the callback cuts off CFG.
So the solution would be either adding a new callback for ControlNet or fixing the current callback so that it also converts the image back to batch size 1; a sketch of such a workaround is shown below.
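Until this is fixed in diffusers, a possible workaround is a custom callback that, at the cutoff step, slices the ControlNet conditioning image back to batch size 1 together with the prompt embeddings. The sketch below only mirrors the structure of the stock `SDXLCFGCutoffCallback`; the class name `ControlNetSDXLCFGCutoffCallback` is made up for illustration, and it assumes that `StableDiffusionXLControlNetPipeline` accepts `image` as a callback tensor input and reads the modified value back after each step. If it does not, the pipeline itself would also need a small patch.

import torch
from diffusers.callbacks import PipelineCallback

class ControlNetSDXLCFGCutoffCallback(PipelineCallback):
    # Hypothetical workaround: same idea as SDXLCFGCutoffCallback, but the
    # ControlNet conditioning image is also sliced so that it stays in sync
    # with the prompt embeddings once CFG is disabled.
    # "image" is assumed to be an accepted callback tensor input of the pipeline.
    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids", "image"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        # Use the explicit step index if given, otherwise derive it from the ratio.
        cutoff_step = (
            cutoff_step_index
            if cutoff_step_index is not None
            else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            # Keep only the conditional half of every CFG-doubled tensor,
            # including the conditioning image that the stock callback ignores.
            for key in self.tensor_inputs:
                if key in callback_kwargs and isinstance(callback_kwargs[key], torch.Tensor):
                    callback_kwargs[key] = callback_kwargs[key][-1:]

            # Disable guidance for the remaining steps.
            pipeline._guidance_scale = 0.0

        return callback_kwargs

With a callback like this, the reproduction below would only need `callback = ControlNetSDXLCFGCutoffCallback(cutoff_step_ratio=0.4)` in place of the stock callback.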
Reproduction
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from diffusers.callbacks import SDXLCFGCutoffCallback
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import cv2
import numpy as np
import torch
original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
image = np.array(original_image)
low_threshold = 100
high_threshold = 200
image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True
)
pipe.enable_model_cpu_offload()
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'
callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    callback_on_step_end=callback,
).images[0]
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
Logs
No response
System Info
- 🤗 Diffusers version: 0.29.0.dev0
- Platform: Linux-4.18.0-408.el8.x86_64-x86_64-with-glibc2.17
- Running on a notebook?: No
- Running on Google Colab?: No
- Python version: 3.8.13
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.23.4
- Transformers version: 4.40.0.dev0
- Accelerate version: 0.28.0
- PEFT version: 0.11.1
- Bitsandbytes version: 0.43.1
- Safetensors version: 0.4.2
- xFormers version: 0.0.22.post7
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
@sayakpaul @yiyixuxu
Thanks for reporting this issue, I'll look into it. The best option here is probably to make the same callback work with all the SDXL-related pipelines.
Hi there! I just ran into this problem. Any update here? Or is there code we could use to override the callback in the meantime? Thanks a lot.
By coincidence, I'm working on this right now. I'll open a PR soon.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
not stale