perf: optimize CausalConv3d for Wan autoencoders
## What does this PR do?
By optimizing `CausalConv3d`, this patch improves the overall performance of the Wan autoencoders by 5-10%. Instead of applying all padding explicitly through `F.pad` on every forward pass, the spatial (H/W) padding is left implicit inside `nn.Conv3d`, so only the causal temporal padding has to be materialized by hand.
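A minimal sketch of the idea (identifiers such as `pw`, `ph`, `pt` and the conv names are illustrative; the complete implementations are in the benchmark script below). An explicit `F.pad` launches an extra memory-bound kernel and allocates a fully padded copy of the activation, whereas implicit spatial padding can be handled inside the convolution kernel without materializing a padded tensor:

```python
# Before (Implementation A): all padding is applied explicitly on every call.
# F.pad pairs run from the last dim inward: (W, W, H, H, T_front, T_back).
x = F.pad(x, (pw, pw, ph, ph, 2 * pt, 0))
out = conv3d_unpadded(x)  # conv configured with padding=(0, 0, 0)

# After (Implementation B): H/W padding stays inside the conv; only the
# causal temporal padding is done by hand, by prepending zero frames.
b, c, t, h, w = x.shape
x = torch.cat([x.new_zeros(b, c, 2 * pt, h, w), x], dim=2)
out = conv3d_spatial_padded(x)  # conv configured with padding=(0, ph, pw)
```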
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Testing script:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton.testing


class CausalConv3d_A(nn.Conv3d):
    """
    Implementation A: fully explicit padding using F.pad.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad pairs run from the last dim inward: (W, W, H, H, T_front, T_back).
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)  # Reset internal padding to 0

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)


class CausalConv3d_B(nn.Conv3d):
    """
    Implementation B: explicit temporal padding, implicit spatial padding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temporal_padding = 2 * self.padding[0]
        # Keep spatial padding inside the conv; remove only the temporal padding.
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None):
        b, c, t, h, w = x.size()
        padding = self.temporal_padding
        if cache_x is not None and self.temporal_padding > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding -= cache_x.shape[2]
        # Manually pad the time dimension by prepending zeros (causal padding)
        if padding > 0:
            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
        return super().forward(x)


def setup_models(in_channels, out_channels, kernel_size, padding):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b.load_state_dict(model_a.state_dict())
    return model_a, model_b


def test_correctness():
    print("\n=== Running Correctness Test ===")
    B, C, T, H, W = 2, 32, 16, 64, 64
    out_C = 64
    kernel = 3
    pad_val = 1  # results in a causal temporal pad of 2 * 1 = 2
    model_a, model_b = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
    model_a.eval()
    model_b.eval()
    device = next(model_a.parameters()).device
    x = torch.randn(B, C, T, H, W, device=device)

    # 1. Test without cache
    with torch.no_grad():
        out_a = model_a(x)
        out_b = model_b(x)
    try:
        torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (No Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ!")
        print(e)
        return

    # 2. Test with a temporal cache, as used during chunked decoding
    cache = torch.randn(B, C, 2, H, W, device=device)
    with torch.no_grad():
        out_a_cache = model_a(x, cache_x=cache)
        out_b_cache = model_b(x, cache_x=cache)
    try:
        torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (With Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ with cache!")
        print(e)


def benchmark_performance():
    print("\n=== Running Performance Benchmark ===")
    if not torch.cuda.is_available():
        print("Skipping benchmark (CUDA not available)")
        return
    B, C, T, H, W = 4, 64, 32, 128, 128
    out_C = 64
    kernel = 3
    # Padding set to (1, 1, 1), so T gets padded by 2, H/W by 1
    model_a, model_b = setup_models(C, out_C, kernel, padding=(1, 1, 1))
    x = torch.randn(B, C, T, H, W, device="cuda")

    def run_a():
        return model_a(x)

    def run_b():
        return model_b(x)

    ms_a = triton.testing.do_bench(run_a, rep=100)
    ms_b = triton.testing.do_bench(run_b, rep=100)
    print(f"Implementation A (F.pad): {ms_a:.3f} ms")
    print(f"Implementation B (Impl.H/W): {ms_b:.3f} ms")
    diff = (ms_a - ms_b) / ms_a * 100
    print(f"Implementation B is {diff:.2f}% faster")


if __name__ == "__main__":
    test_correctness()
    benchmark_performance()
```
Result:

```
=== Running Correctness Test ===
[Pass] Outputs are numerically identical (No Cache).
[Pass] Outputs are numerically identical (With Cache).

=== Running Performance Benchmark ===
Implementation A (F.pad): 44.787 ms
Implementation B (Impl.H/W): 42.507 ms
Implementation B is 5.09% faster
```
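As an extra sanity check (not part of the script above), the causal property itself can also be verified directly: perturbing future frames must not change earlier output frames. A sketch, reusing `setup_models` from the script:

```python
# Hypothetical extra check: output frame t may only depend on input frames <= t,
# so editing frames 10+ must leave output frames 0..9 bitwise unchanged.
model_a, model_b = setup_models(32, 64, 3, padding=(1, 1, 1))
device = next(model_b.parameters()).device
x = torch.randn(2, 32, 16, 64, 64, device=device)
x_perturbed = x.clone()
x_perturbed[:, :, 10:] += 1.0  # change only frames 10 and later

with torch.no_grad():
    out_ref = model_b(x)
    out_pert = model_b(x_perturbed)

torch.testing.assert_close(out_ref[:, :, :10], out_pert[:, :, :10])
print("[Pass] Output frames before the perturbation are unaffected.")
```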
@sayakpaul @yiyixuxu @DN6 Please take a look, thanks!
@c8ef could we also check if the performance further improves with `torch.compile`?
> @c8ef could we also check if the performance further improves with `torch.compile`?
After compilation, the two `CausalConv3d` modules have similar performance.
However, I believe this optimization is still useful in scenarios where we cannot compile: for example, when compiling the entire VAE adds too much startup latency, or when compiling certain modules may negatively impact video quality.
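For reference, a sketch of how such a compiled comparison could be run (a hypothetical reproduction, not the exact commands used: it reuses `setup_models` and the benchmark shapes from the script above, with default `torch.compile` settings):

```python
# Benchmark both implementations again after wrapping them in torch.compile.
model_a, model_b = setup_models(64, 64, 3, padding=(1, 1, 1))
x = torch.randn(4, 64, 32, 128, 128, device="cuda")

compiled_a = torch.compile(model_a)
compiled_b = torch.compile(model_b)

# Warm up so the one-time compilation cost is excluded from the timing.
for _ in range(3):
    compiled_a(x)
    compiled_b(x)
torch.cuda.synchronize()

ms_a = triton.testing.do_bench(lambda: compiled_a(x), rep=100)
ms_b = triton.testing.do_bench(lambda: compiled_b(x), rep=100)
print(f"Compiled A (F.pad):    {ms_a:.3f} ms")
print(f"Compiled B (Impl.H/W): {ms_b:.3f} ms")
```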
Any reason for closing?
> Any reason for closing?
Sorry, spotty connection and my phone glitched when I touched the screen...
I'm wondering how we can move this forward. From your perspective as reviewers, is this patch worth implementing?
Gentle ping.