⚡️ Speed up method `Kandinsky3ConditionalGroupNorm.forward` by 7%
📄 7% (1.07×) speedup for `Kandinsky3ConditionalGroupNorm.forward` in `src/diffusers/models/unets/unet_kandinsky3.py`
⏱️ Runtime: 2.16 milliseconds → 2.02 milliseconds (best of 332 runs)
📝 Explanation and details
The most important optimizations for this method, based on the line-profiling results:

- The main bottleneck is `self.norm(x) * (scale + 1.0) + shift` and the `self.context_mlp(context)` call.
- The loop that repeatedly applies `unsqueeze` to the context tensor is inefficient.
- Context expansion can be vectorized with `.view` or `.reshape` to produce the desired broadcastable shape all at once, rather than unsqueezing in a loop.

The improved code (sketched below) removes the loop, performs the shape expansion more efficiently, and should provide speedups for larger batch sizes or channel/image sizes.

Summary of optimizations:

- Removed the for-loop in favor of an efficient tensor view: the repeated `unsqueeze` calls are replaced with a single `view`, which is much faster at producing the broadcasting shape.
- Precompute and reuse shapes: `x.dim()` is used to compute the required broadcast shape once, with no per-dimension Python looping.
- All existing semantics and output shapes are preserved.
- No unnecessary temporary allocations or autograd-graph buildup.

This rewrite keeps the function signature and logic unchanged, but should yield notable performance improvements, especially for large spatial tensors.
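The optimized code block itself was not preserved in this dump. Below is a minimal reconstruction of the described forward pass, assuming the module layout of `Kandinsky3ConditionalGroupNorm` in diffusers (non-affine `GroupNorm` plus a zero-initialized SiLU→Linear context MLP); it is an illustrative sketch, not the verbatim patch:

```python
import torch
from torch import nn


class ConditionalGroupNormSketch(nn.Module):
    """Sketch of the optimized conditional group norm (assumed layout, not the exact patch)."""

    def __init__(self, groups: int, normalized_shape: int, context_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(
            nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape)
        )
        # Zero-init: scale and shift start at 0, so the module initially reduces
        # to plain GroupNorm (the property test_forward_context_zero_affine relies on).
        self.context_mlp[1].weight.data.zero_()
        self.context_mlp[1].bias.data.zero_()

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        context = self.context_mlp(context)  # (B, 2C)
        # Original: a Python loop of unsqueeze(-1), one call per spatial dim.
        # Optimized: a single view to the broadcastable shape (B, 2C, 1, ..., 1).
        context = context.view(*context.shape, *([1] * (x.dim() - 2)))
        scale, shift = context.chunk(2, dim=1)  # each (B, C, 1, ..., 1)
        return self.norm(x) * (scale + 1.0) + shift
```

`view` is safe here because the output of `nn.Linear` is contiguous, so the call is a zero-copy metadata change; `reshape` would behave identically but falls back to a copy for non-contiguous inputs.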
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 31 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
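The behavioral equivalence these tests verify ultimately reduces to the loop-based and view-based context expansions producing identical tensors. That property can be checked standalone; a minimal sketch using only the shapes involved:

```python
import torch

x = torch.randn(2, 4, 8, 8)  # (B, C, H, W) input
ctx = torch.randn(2, 8)      # (B, 2C), as produced by the context MLP

# Original approach: one unsqueeze(-1) per spatial dimension.
loop_ctx = ctx
for _ in range(x.dim() - 2):
    loop_ctx = loop_ctx.unsqueeze(-1)

# Optimized approach: a single view to the broadcastable shape.
view_ctx = ctx.view(*ctx.shape, *([1] * (x.dim() - 2)))

assert torch.equal(loop_ctx, view_ctx)  # same shape (2, 8, 1, 1) and values
```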
🌀 Generated Regression Tests Details
import pytest # used for our unit tests
import torch # used for tensor operations
from src.diffusers.models.unets.unet_kandinsky3 import \
Kandinsky3ConditionalGroupNorm
from torch import nn
# unit tests
# ---- Basic Test Cases ----
def test_forward_basic_2d_batch():
# Test with 2D spatial input, batch size 2, 4 channels, 2 groups, context_dim 8
batch, channels, height, width = 2, 4, 8, 8
groups = 2
context_dim = 8
x = torch.randn(batch, channels, height, width)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
# Output should be differentiable
out.sum().backward()
def test_forward_basic_1d():
# Test with 1D input (e.g., sequence), batch size 3, 6 channels, 3 groups, context_dim 4
batch, channels, length = 3, 6, 16
groups = 3
context_dim = 4
x = torch.randn(batch, channels, length)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_basic_3d():
# Test with 3D input (e.g., video), batch size 1, 8 channels, 2 groups, context_dim 10
batch, channels, d, h, w = 1, 8, 2, 4, 4
groups = 2
context_dim = 10
x = torch.randn(batch, channels, d, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_context_zero_affine():
# Test that with zero-initialized context_mlp, output equals GroupNorm(x)
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 7
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
# Since context_mlp is zero-initialized, scale=0, shift=0, so output=GroupNorm(x)
codeflash_output = model.forward(x, context); out = codeflash_output
baseline = model.norm(x)
# ---- Edge Test Cases ----
def test_forward_single_element_batch():
# Test with batch size 1
batch, channels, h, w = 1, 4, 5, 5
groups = 2
context_dim = 3
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_single_channel():
# Test with single channel (groups=1)
batch, channels, h, w = 2, 1, 4, 4
groups = 1
context_dim = 5
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_single_spatial():
# Test with single spatial dimension (e.g., length=1)
batch, channels, length = 2, 4, 1
groups = 2
context_dim = 4
x = torch.randn(batch, channels, length)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_mismatched_context_dim():
# Test with wrong context_dim (should raise an error)
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 7
x = torch.randn(batch, channels, h, w)
wrong_context = torch.randn(batch, context_dim + 1)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
with pytest.raises(RuntimeError):
model.forward(x, wrong_context)
def test_forward_mismatched_batch_size():
# Test with mismatched batch size between x and context (should raise error)
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 5
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch + 1, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
with pytest.raises(RuntimeError):
model.forward(x, context)
def test_forward_invalid_groups():
# Test with groups not dividing channels evenly (should raise error)
batch, channels, h, w = 2, 5, 8, 8
groups = 3 # 5 not divisible by 3
context_dim = 4
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
with pytest.raises(ValueError):
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
model.forward(x, context)
def test_forward_empty_input():
# Test with empty input tensor (should raise error)
batch, channels, h, w = 0, 4, 8, 8
groups = 2
context_dim = 4
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
with pytest.raises(Exception):
model.forward(x, context)
def test_forward_nan_inf_input():
# Test with NaN and Inf values in x
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 4
x = torch.randn(batch, channels, h, w)
x[0, 0, 0, 0] = float('nan')
x[1, 1, 1, 1] = float('inf')
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_nan_inf_context():
# Test with NaN and Inf values in context
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 4
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
context[0, 0] = float('nan')
context[1, 1] = float('inf')
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
# ---- Large Scale Test Cases ----
def test_forward_large_batch():
# Test with large batch size
batch, channels, h, w = 128, 4, 8, 8
groups = 2
context_dim = 8
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_channels():
# Test with large number of channels (but less than 100MB)
batch, channels, h, w = 2, 256, 8, 8 # 2*256*8*8*4B = 131072B = 128KB
groups = 16
context_dim = 32
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_spatial():
# Test with large spatial dimensions (but less than 100MB)
    batch, channels, h, w = 2, 8, 64, 64  # 2*8*64*64*4B = 262144B = 256KB
groups = 4
context_dim = 8
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_3d():
# Test with large 3D input (e.g., volumetric data)
batch, channels, d, h, w = 1, 16, 16, 8, 8 # 1*16*16*8*8*4B = 65536B = 64KB
groups = 4
context_dim = 16
x = torch.randn(batch, channels, d, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_context_dim():
# Test with large context dimension
batch, channels, h, w = 2, 8, 8, 8
groups = 4
context_dim = 512
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_performance():
# Test that forward pass runs in reasonable time for large input
import time
batch, channels, h, w = 16, 32, 32, 32 # 16*32*32*32*4B = 2MB
groups = 8
context_dim = 32
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
start = time.time()
codeflash_output = model.forward(x, context); out = codeflash_output
elapsed = time.time() - start
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest # used for our unit tests
import torch # for tensor operations
from src.diffusers.models.unets.unet_kandinsky3 import \
Kandinsky3ConditionalGroupNorm
from torch import nn
# unit tests
# --------- BASIC TEST CASES ---------
def test_forward_basic_2d():
# Simple 2D input (batch, channels, height, width)
batch, channels, height, width = 2, 4, 8, 8
groups = 2
context_dim = 5
x = torch.randn(batch, channels, height, width)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_basic_1d():
# 1D input (batch, channels, length)
batch, channels, length = 3, 6, 10
groups = 3
context_dim = 7
x = torch.randn(batch, channels, length)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_basic_3d():
# 3D input (batch, channels, depth, height, width)
batch, channels, d, h, w = 1, 8, 4, 4, 4
groups = 4
context_dim = 3
x = torch.randn(batch, channels, d, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_context_broadcasting():
# Check that context is broadcast correctly for different spatial shapes
batch, channels, h, w = 2, 4, 12, 7
groups = 2
context_dim = 6
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
# Output should be different from GroupNorm(x) due to context
gn = nn.GroupNorm(groups, channels, affine=False)
normed = gn(x)
# --------- EDGE TEST CASES ---------
def test_forward_single_batch():
# Single batch
batch, channels, h, w = 1, 4, 5, 5
groups = 2
context_dim = 4
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_single_channel():
    # Single channel with groups=1 (a valid configuration, should not raise)
batch, channels, h, w = 2, 1, 8, 8
groups = 1
context_dim = 3
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_minimal_spatial():
# Minimal spatial dimensions (1x1)
batch, channels, h, w = 2, 2, 1, 1
groups = 2
context_dim = 2
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_context_wrong_shape():
# Context batch size mismatch
batch, channels, h, w = 2, 4, 4, 4
groups = 2
context_dim = 5
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch+1, context_dim) # Wrong batch size
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
with pytest.raises(RuntimeError):
codeflash_output = model.forward(x, context); _ = codeflash_output
def test_forward_context_dim_mismatch():
# Context feature dimension mismatch
batch, channels, h, w = 2, 4, 4, 4
groups = 2
context_dim = 5
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim+1) # Wrong context_dim
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
with pytest.raises(RuntimeError):
codeflash_output = model.forward(x, context); _ = codeflash_output
def test_forward_invalid_groups():
# Invalid group number (not dividing channels)
batch, channels, h, w = 2, 5, 4, 4
groups = 2 # 5 not divisible by 2
context_dim = 3
with pytest.raises(ValueError):
_ = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
def test_forward_empty_tensor():
# Empty input tensor
batch, channels, h, w = 0, 4, 4, 4
groups = 2
context_dim = 3
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
# --------- LARGE SCALE TEST CASES ---------
def test_forward_large_batch():
# Large batch size, but under 100MB
batch, channels, h, w = 128, 8, 16, 16
groups = 4
context_dim = 16
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_channels():
# Large number of channels, but under 100MB
batch, channels, h, w = 4, 512, 8, 8
groups = 8
context_dim = 32
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_spatial():
# Large spatial dimensions, but under 100MB
batch, channels, h, w = 2, 16, 128, 128
groups = 4
context_dim = 8
x = torch.randn(batch, channels, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_large_3d():
# Large 3D input, but under 100MB
batch, channels, d, h, w = 1, 16, 16, 16, 16
groups = 4
context_dim = 8
x = torch.randn(batch, channels, d, h, w)
context = torch.randn(batch, context_dim)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
def test_forward_gradient_flow():
# Check that gradients flow through both x and context
batch, channels, h, w = 2, 4, 8, 8
groups = 2
context_dim = 4
x = torch.randn(batch, channels, h, w, requires_grad=True)
context = torch.randn(batch, context_dim, requires_grad=True)
model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
codeflash_output = model.forward(x, context); out = codeflash_output
loss = out.sum()
loss.backward()
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run `git checkout codeflash/optimize-Kandinsky3ConditionalGroupNorm.forward-mb5lqa87` and push.