
Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning

Open adi776borate opened this issue 3 weeks ago • 9 comments

What does this PR do?

Fixes #12809. This PR fixes it by:

  1. Removing the @torch.autocast decorator (Fixes the import warning).
  2. Explicitly casting inputs to float32 inside the forward method (Preserves the required numerical stability).
  3. Casting the result back to weight.dtype before passing it to the Linear layers (Fixes the dtype mismatch crash).

Verification

I verified that the results remain stable before and after this change by generating images with a fixed seed (generator=torch.manual_seed(42)).

The results are almost the same with some minor differences.

[Images, side by side: kandinsky_before_fix (Before Fix) and kandinsky_after_fix (After Fix)]
Reproduction Script
import torch
from diffusers import Kandinsky5T2IPipeline

model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

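# Only query bf16 support when CUDA is actually present; fall back to float32 otherwise.
dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32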
pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)

seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

print("Generating image...")
output = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    negative_prompt="",
    num_inference_steps=25,  # Reduced for faster verification
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=generator,
)

image = output.image[0]
image.save("kandinsky_after_fix.png")

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

@yiyixuxu @leffff Anyone in the community is free to review the PR once the tests have passed.

adi776borate avatar Dec 09 '25 15:12 adi776borate

Looks good to me!

leffff avatar Dec 09 '25 18:12 leffff

Thanks for the quick fix! I didn't have time to submit a PR myself, so I really appreciate you jumping on this. 🙏 @adi776borate

knd0331 avatar Dec 10 '25 00:12 knd0331

@yiyixuxu @sayakpaul A gentle ping to review

adi776borate avatar Dec 11 '25 10:12 adi776borate

Thank you! Could you also provide your testing script?

The verification script is already provided in the PR description above. If you want to test minimally, we can just do:

from diffusers.models.transformers import transformer_kandinsky
print("Import successful.")

This should print a UserWarning on main, but not on this branch.
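The import check above can also be automated. The following is a minimal sketch, not part of the PR: it imports a module in a fresh interpreter (so an already-cached import cannot hide the warning) and looks for `UserWarning` on stderr. The helper name `import_emits_user_warning` is hypothetical.

```python
import subprocess
import sys


def import_emits_user_warning(module_name: str) -> bool:
    """Import `module_name` in a fresh interpreter; report whether a UserWarning was printed."""
    result = subprocess.run(
        # -W always: show every warning, including ones normally shown only once.
        [sys.executable, "-W", "always", "-c", f"import {module_name}"],
        capture_output=True,
        text=True,
    )
    # Python prints warnings to stderr as "<file>:<line>: UserWarning: <message>".
    return "UserWarning" in result.stderr
```

With diffusers installed, `import_emits_user_warning("diffusers.models.transformers.transformer_kandinsky")` would be expected to return True on main and False on this branch. Note that the helper also returns False if the import itself fails, so a missing package has to be ruled out separately.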

adi776borate avatar Dec 11 '25 11:12 adi776borate

@bot /style

yiyixuxu avatar Dec 11 '25 15:12 yiyixuxu

Style bot fixed some files and pushed the changes.

github-actions[bot] avatar Dec 11 '25 15:12 github-actions[bot]

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This is incorrect.

Minimal reproduction

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freqs(dim, max_period=10000.0):
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    # Behaviour on main: hardcoded CUDA autocast runs the whole forward pass in float32.
    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsPR(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Behaviour in this PR: cast back to the weight dtype, so the Linear layers run in bfloat16.
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed


torch.manual_seed(0)
with_autocast = (
    Kandinsky5TimeEmbeddings(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
pr = (
    Kandinsky5TimeEmbeddingsPR(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)
torch.manual_seed(0)
no_autocast = (
    Kandinsky5TimeEmbeddingsNoAutocast(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)


with torch.no_grad():
    time = torch.tensor([952.0]).to("cuda", torch.bfloat16)
    with_out = with_autocast(time.clone())
    pr_out = pr(time.clone())
    no_out = no_autocast(time.clone())

print(f"{with_out.dtype=}, {pr_out.dtype=}, {no_out.dtype=}")
try:
    print(f"{torch.allclose(with_out, pr_out)=}")
except RuntimeError as e:
    print(f"{e}, casting")
    print(f"{torch.allclose(with_out.to(pr_out.dtype), pr_out)=}")

print(f"{torch.allclose(with_out, no_out)=}")

with_out.dtype=torch.float32, pr_out.dtype=torch.bfloat16, no_out.dtype=torch.float32
Float did not match BFloat16, casting
torch.allclose(with_out.to(pr_out.dtype), pr_out)=False
torch.allclose(with_out, no_out)=True

As we see from the minimal reproduction of Kandinsky5TimeEmbeddings, the output from this PR does not match the output from main.

@torch.autocast(device_type="cuda", dtype=torch.float32) means everything is cast to float32: the Linear layers and activation also run in float32, and the output of forward is float32.

In this PR, the Linear layers and activation run in bfloat16, which changes the module's output and, in turn, the output image.
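To see why running these layers in bfloat16 changes the result, recall that bfloat16 keeps the float32 exponent but only 8 bits of significand precision. The sketch below emulates the float32 to bfloat16 rounding in pure Python (round-to-nearest-even on the upper 16 bits). It is an illustration for finite values only, not PyTorch's implementation, and the helper names are hypothetical.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Emulate float32 -> bfloat16 rounding (nearest-even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 is the upper 16 bits of float32; the added bias makes truncation round to nearest-even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# The timestep used in the reproduction is exactly representable in bfloat16 ...
print(to_bfloat16(952.0))  # 952.0
# ... but a typical cos/sin embedding value is not, so casting it to bfloat16 loses precision.
value = math.cos(0.3)
print(to_float32(value), to_bfloat16(value))
```

Since the input timestep survives the cast unchanged, the mismatch in the reproduction comes from the embedding values and the Linear layers being evaluated at reduced precision, not from the input itself.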

Kandinsky5TimeEmbeddings should be:


class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed

and Kandinsky5Modulation:


class Kandinsky5Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, num_params * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x):
        return F.linear(
            self.activation(x.to(torch.float32)),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )

With those changes, image output matches exactly:

[Images, side by side: kandinsky_before_fix (Main) and kandinsky_after_fix (Fix)]
PS C:\Users\user\Downloads> certutil -hashfile kandinsky_after_fix.png SHA256
SHA256 hash of kandinsky_after_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.
PS C:\Users\user\Downloads> certutil -hashfile kandinsky_before_fix.png SHA256
SHA256 hash of kandinsky_before_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.
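The same byte-for-byte comparison can be done on any platform with Python's standard library. This is a minimal sketch; the helper names `sha256_of` and `files_identical` are hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str) -> str:
    """Return the SHA-256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def files_identical(path_a: str, path_b: str) -> bool:
    """True if both files have the same SHA-256 digest."""
    return sha256_of(path_a) == sha256_of(path_b)
```

For the images above, the call would be `files_identical("kandinsky_before_fix.png", "kandinsky_after_fix.png")`. Matching digests mean the PNG files are byte-identical, which is a stricter check than comparing decoded pixels.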

Perhaps the changes could be slightly simplified by making use of _keep_in_fp32_modules so we wouldn't need to cast the weights, but we would still need to cast everything else.

hlky avatar Dec 13 '25 20:12 hlky

Thanks for the detailed analysis and script @hlky! You are right.

I misunderstood the original author's intent. I assumed they only wanted to protect specific operations (like sin/cos) from overflow.

Regarding _keep_in_fp32_modules: I prefer the manual F.linear approach because it allows us to keep the weights stored in bfloat16/float16 (saving VRAM) and only cast them on-the-fly, whereas _keep_in_fp32_modules would force them to be stored in FP32 permanently.
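The storage side of this trade-off can be estimated with simple arithmetic. The sketch below counts the parameters of one time-embedding module using the dimensions from the reproduction script (`model_dim=2560`, `time_dim=512`). Whether the released checkpoint uses these sizes is an assumption, and the modulation layers are not counted.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Number of parameters in a Linear layer with bias."""
    return in_features * out_features + out_features


# Dimensions taken from the reproduction script (assumed, not read from a checkpoint).
model_dim, time_dim = 2560, 512
n_params = linear_params(model_dim, time_dim) + linear_params(time_dim, time_dim)

bytes_bf16 = n_params * 2  # bfloat16/float16: 2 bytes per parameter
bytes_fp32 = n_params * 4  # float32: 4 bytes per parameter

print(f"parameters: {n_params:,}")
print(f"stored in bf16: {bytes_bf16 / 2**20:.2f} MiB")
print(f"stored in fp32: {bytes_fp32 / 2**20:.2f} MiB")
```

Casting on the fly avoids the permanent float32 copy, at the cost of creating temporary float32 weights on every forward call.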

I'll update the PR with your suggested fix. Thanks again!

adi776borate avatar Dec 14 '25 05:12 adi776borate