[Performance] Issue in the *SanaLinearAttnProcessor2_0 family. A 1.06X speedup can be reached with a simple change.
Sys env:
OS: Ubuntu 22.04, PyTorch: 2.4.0+cu121, sana == 0.0.1, Diffusers == 0.34.0.dev0
Reproduce:
Try the demo test code:
import torch
from diffusers import SanaPAGPipeline

pipe = SanaPAGPipeline.from_pretrained(
    # "Efficient-Large-Model/Sana_1600M_512px_diffusers",
    "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
    pag_applied_layers="transformer_blocks.8",
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
    prompt=prompt,
    guidance_scale=5.0,
    pag_scale=2.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
image[0].save('sana.png')
During inference, the data goes through SanaLinearAttnProcessor2_0.
Issue Description:
Lines 6042 and 6043 first transpose a contiguous tensor and then perform a type cast. The type cast copies the data from the old-dtype tensor into a new one. But if you check the new tensor's contiguity, you will see:
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
hidden_states = hidden_states.to(original_dtype)
print("Contiguity after type casting: ", hidden_states.is_contiguous()) # False
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
The problem is that the casting copy only converts the dtype while preserving the input tensor's strides (the default memory_format is torch.preserve_format), so the result is still non-contiguous. That badly strided tensor is then consumed directly by the two to_out layers, and the inefficiency propagates downstream.
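A minimal standalone sketch of the behavior described above (the shape is illustrative only, not taken from the model):
import torch

x = torch.randn(1, 64, 2240)                 # contiguous float32 tensor
y = x.transpose(1, 2)                        # a strided view, not contiguous
z = y.to(torch.bfloat16)                     # copies the data, but keeps the transposed strides (memory_format defaults to torch.preserve_format)
print(y.is_contiguous(), z.is_contiguous())  # False False
print(y.stride(), z.stride())                # identical stride patterns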
How to Fix:
Let hidden_states.to(original_dtype) perform the type cast and the contiguous copy in a single pass.
One possible approach:
@torch.compile
def transpose_cast_kernel(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    torch-compiled kernel that casts a 3D tensor to bfloat16 and transposes its last two dims, returning a contiguous result
    """
    converted = input_tensor.to(torch.bfloat16)
    transposed = torch.transpose(converted, 1, 2).contiguous()
    return transposed
Use this fused operation to produce the new tensor (the transpose now happens inside the compiled kernel, so the eager transpose is dropped):
hidden_states = hidden_states.flatten(1, 2)
hidden_states = transpose_cast_kernel(hidden_states)
# hidden_states.is_contiguous() is now True
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
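As a side note (a sketch, not the measured change above): Tensor.to also accepts a memory_format argument, so the cast and the contiguous copy can be requested in a single call even without torch.compile:
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
# assumption: contiguous_format makes the casting copy produce a row-major result in one pass
hidden_states = hidden_states.to(original_dtype, memory_format=torch.contiguous_format)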
Or, your expert team could do even better.
Measurement:
By adopting the change above, SanaLinearAttnProcessor2_0.__call__ enjoys a 1.06X speedup on an RTX 3090. PAGCFGSanaLinearAttnProcessor2_0 and PAGIdentitySanaLinearAttnProcessor2_0 share the same logic and lose performance in the same way.
It looks interesting. Any insights @sayakpaul ?
Very nice issue thread and discussion. @yiyixuxu WDYT about this change?
Indeed, cc @a-r-r-o-w here. Can we run some tests? If true, let's be mindful about this pattern across our code base.
Hi yiyixuxu and arrow. Any follow-ups? :)
Hi, sorry I missed the email notification. I'll look into it this weekend or next week, thanks for reminding!
@David-Dingle I tried benchmarking some simple changes to the processor. In practice, the changes are not faster in eager mode (because calling contiguous is slower than just performing the subsequent operations on the non-contiguous tensor).
Benchmark
import time
import torch
import triton
import triton.runtime as runtime
# Original with transpose and no contiguous
class SanaLinearAttentionOriginal(torch.nn.Module):
    def __init__(self, heads: int, head_dim: int):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.to_q = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_k = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_v = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.norm_q = torch.nn.LayerNorm(heads * head_dim)
        self.norm_k = torch.nn.LayerNorm(heads * head_dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(heads * head_dim, heads * head_dim), torch.nn.Dropout(0.0)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        query = self.norm_q(query)
        key = self.norm_k(key)
        query = query.transpose(1, 2).unflatten(1, (self.heads, -1))
        key = key.transpose(1, 2).unflatten(1, (self.heads, -1)).transpose(2, 3)
        value = value.transpose(1, 2).unflatten(1, (self.heads, -1))
        query = torch.relu(query)
        key = torch.relu(key)
        query, key, value = query.float(), key.float(), value.float()
        value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
        scores = torch.matmul(value, key)
        hidden_states = torch.matmul(scores, query)
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
        hidden_states = hidden_states.to(original_dtype)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
# Tests permute + contiguous before linear layer
class SanaLinearAttentionModified1(torch.nn.Module):
    def __init__(self, heads: int, head_dim: int):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.to_q = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_k = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_v = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.norm_q = torch.nn.LayerNorm(heads * head_dim)
        self.norm_k = torch.nn.LayerNorm(heads * head_dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(heads * head_dim, heads * head_dim), torch.nn.Dropout(0.0)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        query = self.norm_q(query)
        key = self.norm_k(key)
        query = query.permute(0, 2, 1).unflatten(1, (self.heads, -1))
        key = key.permute(0, 2, 1).unflatten(1, (self.heads, -1)).permute(0, 1, 3, 2)
        value = value.permute(0, 2, 1).unflatten(1, (self.heads, -1))
        query = torch.relu(query)
        key = torch.relu(key)
        query, key, value = query.float(), key.float(), value.float()
        value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
        scores = torch.matmul(value, key)
        hidden_states = torch.matmul(scores, query)
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
        hidden_states = hidden_states.to(original_dtype)
        hidden_states = hidden_states.flatten(1, 2).permute(0, 2, 1).contiguous()
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
# Tests transpose + contiguous before linear layer
class SanaLinearAttentionModified2(torch.nn.Module):
    def __init__(self, heads: int, head_dim: int):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.to_q = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_k = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.to_v = torch.nn.Linear(heads * head_dim, heads * head_dim)
        self.norm_q = torch.nn.LayerNorm(heads * head_dim)
        self.norm_k = torch.nn.LayerNorm(heads * head_dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(heads * head_dim, heads * head_dim), torch.nn.Dropout(0.0)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        query = self.norm_q(query)
        key = self.norm_k(key)
        query = query.transpose(1, 2).unflatten(1, (self.heads, -1))
        key = key.transpose(1, 2).unflatten(1, (self.heads, -1)).transpose(2, 3)
        value = value.transpose(1, 2).unflatten(1, (self.heads, -1))
        query = torch.relu(query)
        key = torch.relu(key)
        query, key, value = query.float(), key.float(), value.float()
        value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
        scores = torch.matmul(value, key)
        hidden_states = torch.matmul(scores, query)
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
        hidden_states = hidden_states.to(original_dtype)
        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2).contiguous()
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
def run_bench(fn, *args, num_warmups: int = 8, num_repeats: int = 32):
    output = fn(*args)
    torch.cuda.synchronize()
    time.sleep(0.5)
    cache = runtime.driver.active.get_empty_cache_for_benchmark()
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    for _ in range(num_warmups):
        runtime.driver.active.clear_cache(cache)
        fn(*args)
    torch.cuda.synchronize()
    time.sleep(0.5)
    for i in range(num_repeats):
        runtime.driver.active.clear_cache(cache)
        start_events[i].record()
        fn(*args)
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    mean_time = sum(elapsed_times) / len(elapsed_times)
    return output, mean_time
@torch.inference_mode()
def benchmark():
    dtype = torch.bfloat16
    device = torch.device("cuda")
    batch_size = 1
    sequence_length = 1024
    encoder_sequence_length = 128
    heads = 70
    head_dim = 32
    torch.manual_seed(42)
    model1 = SanaLinearAttentionOriginal(heads=heads, head_dim=head_dim)
    model2 = SanaLinearAttentionModified1(heads=heads, head_dim=head_dim)
    model3 = SanaLinearAttentionModified2(heads=heads, head_dim=head_dim)
    model2.load_state_dict(model1.state_dict())
    model3.load_state_dict(model1.state_dict())
    model1.to(device=device, dtype=dtype)
    model2.to(device=device, dtype=dtype)
    model3.to(device=device, dtype=dtype)
    hidden_states = torch.randn(batch_size, sequence_length, heads * head_dim, device=device, dtype=dtype)
    encoder_hidden_states = torch.randn(batch_size, encoder_sequence_length, 20 * 112, device=device, dtype=dtype)
    out2, time2 = run_bench(model2, hidden_states, encoder_hidden_states, num_warmups=16, num_repeats=128)
    out1, time1 = run_bench(model1, hidden_states, encoder_hidden_states, num_warmups=16, num_repeats=128)
    out3, time3 = run_bench(model3, hidden_states, encoder_hidden_states, num_warmups=16, num_repeats=128)
    for out in [out2, out3]:
        diff = out1 - out
        absdiff = diff.abs()
        absmax = absdiff.max()
        mae = absdiff.mean()
        mse = (absdiff ** 2).mean()
        print(f"{absmax=:.5f}, {mae=:.5f}, {mse=:.5f}")
    print(f"time original : {time1:.5f} ms")
    print(f"time modified1: {time2:.5f} ms")
    print(f"time modified2: {time3:.5f} ms")


if __name__ == "__main__":
    benchmark()
Results
(nightly-venv) aryan@audace:~/work/diffusers$ seq 10 | xargs -Iz python3 dump16.py
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31323 ms
time modified1: 0.34633 ms
time modified2: 0.34269 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31361 ms
time modified1: 0.34725 ms
time modified2: 0.34356 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31337 ms
time modified1: 0.34732 ms
time modified2: 0.34368 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31323 ms
time modified1: 0.34802 ms
time modified2: 0.34111 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31368 ms
time modified1: 0.34775 ms
time modified2: 0.34185 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31312 ms
time modified1: 0.34863 ms
time modified2: 0.34142 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31376 ms
time modified1: 0.34759 ms
time modified2: 0.34172 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31427 ms
time modified1: 0.34940 ms
time modified2: 0.34177 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31456 ms
time modified1: 0.34854 ms
time modified2: 0.34276 ms
absmax=0.00098, mae=0.00004, mse=0.00000
absmax=0.00098, mae=0.00004, mse=0.00000
time original : 0.31516 ms
time modified1: 0.34944 ms
time modified2: 0.34215 ms
I am able to replicate these timings on a 3090, 4090 and A100, so it seems like trying to make the tensors contiguous in eager mode is significantly slower. I did test compiling the dtype cast + contiguous, which is definitely faster, but we don't add any in-built optimizations or clever tricks (like leveraging existing fused kernel implementations such as torch.addcmul) to the library because they don't port to all devices. Maybe something to consider in the future :)
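For anyone who wants to try the compiled route on their side, a minimal user-level sketch (not something shipped in diffusers; it assumes the demo pipeline from the top of this thread) would be:
# compile the transformer so the inductor backend can fuse the cast/transpose/linear
# sequence and decide for itself whether a contiguous copy pays off
pipe.transformer = torch.compile(pipe.transformer)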
Lurker question here for @a-r-r-o-w: in your benchmark, why are mae and absmax not zero? How come the different implementations incur these very small numerical errors?
(super interesting discussion btw)
I don't know if this is exactly right, but my best guess is that this is just bf16 shenanigans. If you test with float32, the difference would be 0. Changing the memory layout means that the underlying multiplies and adds happen in a different order, which is known to produce different values in bf16 due to its smaller precision (just a 7-bit mantissa) and because floating-point operations are not associative. The same happens when using the tf32/fp16 formats as well.
(nightly-venv) aryan@audace:~/work/diffusers$ python3
Python 3.10.16 (main, Feb 8 2025, 10:07:26) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(1024, dtype=torch.bfloat16)
>>> out1 = x + 0.3 - 0.2
>>> out2 = x - 0.2 + 0.3
>>> (out1 - out2).abs().max()
tensor(0.0078, dtype=torch.bfloat16)
(Thx - actually this also happens with float32, but this is getting off-topic)
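A quick float32 analogue of the bf16 snippet above, for completeness:
import torch

x = torch.randn(1024, dtype=torch.float32)
out1 = x + 0.3 - 0.2
out2 = x - 0.2 + 0.3
print((out1 - out2).abs().max())  # typically a small nonzero value: float32 rounding also depends on operation order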
So the general conclusion of the discussion seems to be that calling contiguous() before any operation would theoretically be faster, but to get the speedup in practice we would need extra stuff (like compilation) that you don't want to use in the diffusers lib?
Calling contiguous may or may not be faster even with compilation. It really depends on the surrounding context of what's it being used for. In this case, it's slightly faster.
> but to get the speedup in practice we would need extra stuff (like compilation) that you don't want to use in the diffusers lib?
Yes true. We generally don't want any compiled functions or other optimizations to exist within the library itself (unless it's a widely followed practice). Any "clever" optimization is intended to be done by the interested user
> So the general conclusion of the discussion seems to be that calling contiguous() before any operation would theoretically be faster
Not necessarily, but in many common cases where people force contiguous-ness, yes (if compiling). Sometimes, the contiguous calls can simply be ignored by the torch compiler if it deems that it will be slower (based on whatever heuristics are built into it). I don't have a minimal example on hand, but I've definitely observed this in the past.
Other times, calling contiguous can be a big footgun. In one case, I had a simple Triton quantized-matmul kernel which assumed A in row-major and B in column-major memory ordering. This kernel was called after another set of operations that already left B column-major. Not knowing better, I put a contiguous call on B, which forced it to be row-major. However, the kernel launcher performed B.t().contiguous().t() to ensure column-majorness of B. This resulted in 2 extra unnecessary contiguous calls per layer (a total of 4 per layer in my case, because there were two instances of this operation). The performance impact was ~8% per training step IIRC in eager mode. With torch.compile, however, it simply removed the contiguous calls :)
I believe this can be marked as resolved :)