Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning
What does this PR do?
Fixes #12809. This PR addresses the issue by:
- Removing the @torch.autocast decorator (Fixes the import warning).
- Explicitly casting inputs to float32 inside the forward method (Preserves the required numerical stability).
- Casting the result back to weight.dtype before passing it to the Linear layers (Fixes the dtype mismatch crash).
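In sketch form, the embedding module after this change looks roughly like the following (simplified for illustration; the actual class in transformer_kandinsky keeps its full attribute set and this is not the verbatim diff):

```python
import math
import torch
import torch.nn as nn


def get_freqs(dim, max_period=10000.0):
    return torch.exp(-math.log(max_period) * torch.arange(dim, dtype=torch.float32) / dim)


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        self.freqs = get_freqs(model_dim // 2, max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        # 1) compute the sinusoidal embedding in float32 for numerical stability
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # 2) cast back to the Linear weights' dtype to avoid a dtype-mismatch crash
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        return self.out_layer(self.activation(self.in_layer(time_embed)))
```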
Verification
I verified that the results remain stable before and after this change by generating images with a fixed seed (42), using the reproduction script below.
The results are almost identical, with only minor differences.
| Before Fix | After Fix |
|---|---|
| *(image attached in the PR)* | *(image attached in the PR)* |
Reproduction Script
```python
import torch
from diffusers import Kandinsky5T2IPipeline

model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)

seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

print("Generating image...")
output = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    negative_prompt="",
    num_inference_steps=25,  # Reduced for faster verification
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=generator,
)

image = output.image[0]
image.save("kandinsky_after_fix.png")
```
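To quantify the "minor differences" rather than judging by eye, a small hedged check can be run on the two renders (file names are assumed; the "before" image would be generated with the same script and seed on main):

```python
# Compare the two renders numerically. Assumes both PNGs exist in the
# working directory and were generated with the script above.
import numpy as np
from PIL import Image

before = np.asarray(Image.open("kandinsky_before_fix.png"), dtype=np.int16)
after = np.asarray(Image.open("kandinsky_after_fix.png"), dtype=np.int16)
print("max abs pixel diff:", np.abs(before - after).max())
print("mean abs pixel diff:", np.abs(before - after).mean())
```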
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
@yiyixuxu @leffff Anyone in the community is free to review the PR once the tests have passed.
Looks good to me!
Thanks for the quick fix! I didn't have time to submit a PR myself, so I really appreciate you jumping on this. 🙏 @adi776borate
@yiyixuxu @sayakpaul A gentle ping to review
Thank you! Could you also provide your testing script?
The verification script is already provided in the PR description above. If you want to test minimally, we can just do:
```python
from diffusers.models.transformers import transformer_kandinsky
print("Import successful.")
```
On main this emits a UserWarning at import time; on this branch it should not.
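For a slightly stronger check in a fresh interpreter, one could record warnings around the import (a sketch, assuming the warning is emitted at import time and the module has not been imported yet):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from diffusers.models.transformers import transformer_kandinsky  # noqa: F401

# On main this list should contain the autocast UserWarning; on this branch
# it should be empty.
print([str(w.message) for w in caught])
```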
@bot /style
Style bot fixed some files and pushed the changes.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is incorrect.
Minimal reproduction
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freqs(dim, max_period=10000.0):
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


class Kandinsky5TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsPR(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed


torch.manual_seed(0)
with_autocast = (
    Kandinsky5TimeEmbeddings(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)

torch.manual_seed(0)
pr = (
    Kandinsky5TimeEmbeddingsPR(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)

torch.manual_seed(0)
no_autocast = (
    Kandinsky5TimeEmbeddingsNoAutocast(model_dim=2560, time_dim=512)
    .to("cuda", torch.bfloat16)
    .eval()
)

with torch.no_grad():
    time = torch.tensor([952.0]).to("cuda", torch.bfloat16)
    with_out = with_autocast(time.clone())
    pr_out = pr(time.clone())
    no_out = no_autocast(time.clone())

print(f"{with_out.dtype=}, {pr_out.dtype=}, {no_out.dtype=}")
try:
    print(f"{torch.allclose(with_out, pr_out)=}")
except RuntimeError as e:
    print(f"{e}, casting")
    print(f"{torch.allclose(with_out.to(pr_out.dtype), pr_out)=}")
print(f"{torch.allclose(with_out, no_out)=}")
```
Output:

```
with_out.dtype=torch.float32, pr_out.dtype=torch.bfloat16, no_out.dtype=torch.float32
Float did not match BFloat16, casting
torch.allclose(with_out.to(pr_out.dtype), pr_out)=False
torch.allclose(with_out, no_out)=True
```
As the minimal reproduction of Kandinsky5TimeEmbeddings shows, the output from this PR does not match the output from main.
`@torch.autocast(device_type="cuda", dtype=torch.float32)` means everything inside forward is cast to float32: the Linear layers and the activation also run in float32, and the output of forward is float32.
In this PR the Linear layers and activation run in bfloat16, which produces a different output from the module and, in turn, a different output image.
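The dtype of the matmul alone is enough to move the result; a minimal illustration with an arbitrary layer of matching shape:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(2560, 512).to("cuda", torch.bfloat16)
x = torch.randn(1, 2560, device="cuda", dtype=torch.float32)

# What this PR effectively does: cast the input down and run the matmul in bf16.
out_bf16 = layer(x.to(torch.bfloat16)).to(torch.float32)
# What autocast(float32) does on main: run the same matmul in fp32.
out_fp32 = F.linear(x, layer.weight.to(torch.float32), layer.bias.to(torch.float32))

print(torch.allclose(out_bf16, out_fp32))   # typically False
print((out_bf16 - out_fp32).abs().max())    # non-trivial gap
```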
Kandinsky5TimeEmbeddings should be:
```python
class Kandinsky5TimeEmbeddingsNoAutocast(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = F.linear(
            self.activation(
                F.linear(
                    time_embed,
                    self.in_layer.weight.to(torch.float32),
                    self.in_layer.bias.to(torch.float32),
                )
            ),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return time_embed
```
and Kandinsky5Modulation:
```python
class Kandinsky5Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, num_params * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x):
        return F.linear(
            self.activation(x.to(torch.float32)),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
```
With those changes, image output matches exactly:
| Main | Fix |
|---|---|
| *(image attached in the comment)* | *(image attached in the comment)* |
```
PS C:\Users\user\Downloads> certutil -hashfile kandinsky_after_fix.png SHA256
SHA256 hash of kandinsky_after_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.

PS C:\Users\user\Downloads> certutil -hashfile kandinsky_before_fix.png SHA256
SHA256 hash of kandinsky_before_fix.png:
3fb7319edc17983593d2a1abc0b5ffed418700f5f7f70d450aefd1e225b52143
CertUtil: -hashfile command completed successfully.
```
Perhaps the changes could be slightly simplified by making use of _keep_in_fp32_modules so we wouldn't need to cast the weights, but we would still need to cast everything else.
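For illustration only, that alternative would look roughly like the following on the transformer class; the class name and the "time_embeddings" attribute name are assumptions here, not the actual identifiers:

```python
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin


class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
    # Submodules whose names match entries here keep float32 weights when the
    # model is loaded with a lower-precision torch_dtype.
    _keep_in_fp32_modules = ["time_embeddings"]
```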
Thanks for the detailed analysis and script @hlky! You are right.
I misunderstood the original author's intent. I assumed they only wanted to protect specific operations (like sin/cos) from overflow.
Regarding _keep_in_fp32_modules: I prefer the manual F.linear approach because it allows us to keep the weights stored in bfloat16/float16 (saving VRAM) and only cast them on-the-fly, whereas _keep_in_fp32_modules would force them to be stored in FP32 permanently.
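For rough intuition on the storage side of that trade-off, a toy comparison (layer sizes are arbitrary, not the real module shapes):

```python
import torch
import torch.nn as nn

def param_bytes(module: nn.Module) -> int:
    # persistent parameter storage in bytes
    return sum(p.numel() * p.element_size() for p in module.parameters())

bf16_layer = nn.Linear(512, 6 * 2560).to(torch.bfloat16)  # toy stand-in
fp32_layer = nn.Linear(512, 6 * 2560).to(torch.float32)
print(param_bytes(bf16_layer), param_bytes(fp32_layer))   # fp32 stores twice the bytes
```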
I'll update the PR with your suggested fix. Thanks again!