
Non-power-of-2 sized images fail with Stable Cascade

Open Teriks opened this issue 1 year ago • 3 comments

Describe the bug

I am not sure whether this is a bug, but I cannot find any mention in the documentation of a power-of-2 size limitation.

Generating an image whose width or height is not a power of 2 fails (for example, 1024x680 in the script below).

Reproduction

import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

device = "cuda"
num_images_per_prompt = 1

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                       torch_dtype=torch.float16)

prompt = "Anthropomorphic cat dressed as a pilot"
negative_prompt = ""

prior.enable_model_cpu_offload()
decoder.enable_model_cpu_offload()

prior_output = prior(
    prompt=prompt,
    height=1024,
    width=680,  # not a power of 2; this is the value that triggers the failure
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=20
)

decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings.half(),
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10
).images

Logs

Traceback (most recent call last):
  File "REDACT\test.py", line 28, in <module>
    decoder_output = decoder(
                     ^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\pipelines\stable_cascade\pipeline_stable_cascade.py", line 443, in __call__
    predicted_latents = self.decoder(
                        ^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\models\unets\unet_stable_cascade.py", line 605, in forward
    x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\models\unets\unet_stable_cascade.py", line 553, in _up_decode
    x = block(x, skip)
        ^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\models\unets\unet_stable_cascade.py", line 74, in forward
    x = self.norm(self.depthwise(x))
                  ^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (float) and bias type (struct c10::Half) should be the same

Process finished with exit code 1

System Info

diffusers 0.27.2 torch 2.2.2+cu121

Who can help?

No response

Teriks, Apr 12 '24 02:04

Setting width=720 instead results in a different exception:

Traceback (most recent call last):
  File "REDACT\test.py", line 30, in <module>
    decoder_output = decoder(
                     ^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\pipelines\stable_cascade\pipeline_stable_cascade.py", line 443, in __call__
    predicted_latents = self.decoder(
                        ^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\diffusers\models\unets\unet_stable_cascade.py", line 595, in forward
    x = self.embedding(sample)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\pixelshuffle.py", line 110, in forward
    return F.pixel_unshuffle(input, self.downscale_factor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=181 is not divisible by 2

Process finished with exit code 1

Teriks, Apr 12 '24 02:04

Hi @Teriks, the decoder's patch size is configured as 2, and the embedding layer is built from that patch size; this appears to be what is causing the issue.
https://github.com/huggingface/diffusers/blob/1c000d46e1c821d7bcc267952475373981d0feea/src/diffusers/models/unets/unet_stable_cascade.py#L284

On the surface it looks like we could decouple these, although I'm unsure whether that would affect the generated output. cc: @kashif

DN6, Apr 15 '24 03:04
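The odd width of 181 in the width=720 traceback can be reproduced with plain arithmetic. The sketch below assumes the prior sizes its image embeddings as `ceil(pixels / 42.67)` and the decoder scales those by `10.67` and truncates to an integer; 42.67 and 10.67 are assumed to be the pipelines' default `resolution_multiple` and `latent_dim_scale` values, and the helper functions are hypothetical, not diffusers API. Under those assumptions, 720 px gives a decoder latent width of 181, which a patch size of 2 cannot divide.

```python
import math

# Assumed pipeline defaults (not read from diffusers here):
# prior `resolution_multiple` and decoder `latent_dim_scale`.
RESOLUTION_MULTIPLE = 42.67
LATENT_DIM_SCALE = 10.67
PATCH_SIZE = 2  # decoder patch size mentioned in the comment above


def prior_latent_size(pixels: int) -> int:
    """Side length of the prior's image embedding for a requested pixel size (hypothetical helper)."""
    return math.ceil(pixels / RESOLUTION_MULTIPLE)


def decoder_latent_size(pixels: int) -> int:
    """Side length of the decoder latent derived from the prior's embedding (hypothetical helper)."""
    return int(prior_latent_size(pixels) * LATENT_DIM_SCALE)


for width in (1024, 680, 720, 768):
    size = decoder_latent_size(width)
    # pixel_unshuffle in the decoder embedding needs the side divisible by the patch size
    print(f"width={width}: prior={prior_latent_size(width)}, "
          f"decoder={size}, divisible by patch size: {size % PATCH_SIZE == 0}")
```

Note that 680 gives an even decoder width (170) and passes this check, yet it still fails deeper in the UNet (the first traceback), so divisibility by the patch size appears necessary but not sufficient.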

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], May 12 '24 15:05

As of diffusers 0.29.2, it seems that image dimensions aligned to multiples of 128 work.

Teriks, Jul 05 '24 02:07
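A possible workaround is to snap the requested height and width to a multiple of 128 before calling the prior pipeline. The 128 value comes from the observation in the comment above about 0.29.2, not from any documented constraint, and `snap_to_multiple` / `snap_size` are hypothetical helpers, not part of diffusers. This sketch rounds down so the output never exceeds the requested size, with one multiple as the minimum.

```python
def snap_to_multiple(pixels: int, multiple: int = 128) -> int:
    """Round a side length down to the nearest multiple, never below one multiple (hypothetical helper)."""
    return max(multiple, (pixels // multiple) * multiple)


def snap_size(height: int, width: int, multiple: int = 128) -> tuple[int, int]:
    """Snap both dimensions; 128 is the alignment reported to work, an assumption here."""
    return snap_to_multiple(height, multiple), snap_to_multiple(width, multiple)


# Example: the 1024x680 request from the report becomes 1024x640,
# which could then be passed as height=/width= to the prior pipeline.
height, width = snap_size(1024, 680)
print(height, width)
```

Rounding up instead would preserve at least the requested area but produce larger images; rounding down keeps memory use at or below what was asked for.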

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Sep 14 '24 15:09