
perf: optimize CausalConv3d for wan autoencoders

Open · c8ef opened this pull request 1 month ago • 7 comments

What does this PR do?

This patch optimizes CausalConv3d, improving the overall performance of the Wan autoencoders by 5-10%.
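
For context, a minimal sketch of the idea is shown below (the actual diff may differ in details; see Implementation B in the testing script further down). Instead of disabling Conv3d's built-in padding and materializing all padding with F.pad, the spatial (H/W) padding is left to Conv3d and only the causal temporal padding is applied explicitly. The class name CausalConv3dSketch is just for illustration.

import torch
import torch.nn as nn


class CausalConv3dSketch(nn.Conv3d):
    """Minimal sketch of the optimized layer (cache handling omitted)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Causal padding on the time axis is applied manually in forward().
        self.temporal_padding = 2 * self.padding[0]
        # Spatial (H/W) padding stays inside Conv3d and is handled implicitly.
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x):
        b, c, _, h, w = x.shape
        if self.temporal_padding > 0:
            # Zero-pad only the front of the time dimension (causal).
            x = torch.cat([x.new_zeros(b, c, self.temporal_padding, h, w), x], dim=2)
        return super().forward(x)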

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

c8ef avatar Dec 06 '25 07:12 c8ef

testing script:

import torch
import torch.nn as nn
import torch.nn.functional as F

import triton.testing  # provides do_bench for GPU timing


class CausalConv3d_A(nn.Conv3d):
    """
    Implementation A: Fully explicit padding using F.pad
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)  # Disable Conv3d's built-in padding; all padding is applied in forward()

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Prepend cached frames from the previous chunk and shrink the
            # causal zero-padding on the time axis accordingly.
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)


class CausalConv3d_B(nn.Conv3d):
    """
    Implementation B: Explicit Temporal padding, Implicit Spatial padding
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temporal_padding = 2 * self.padding[0]
        # Keep spatial padding, remove temporal padding from conv layer
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None):
        b, c, t, h, w = x.size()
        padding = self.temporal_padding
        if cache_x is not None and self.temporal_padding > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding -= cache_x.shape[2]

        # Manually pad time dimension
        if padding > 0:
            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
        return super().forward(x)


def setup_models(in_channels, out_channels, kernel_size, padding):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)

    model_b.load_state_dict(model_a.state_dict())

    return model_a, model_b


def test_correctness():
    print("\n=== Running Correctness Test ===")
    B, C, T, H, W = 2, 32, 16, 64, 64
    out_C = 64
    kernel = 3
    pad_val = 1  # resulting in causal pad of 2*1=2

    model_a, model_b = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
    model_a.eval()
    model_b.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(B, C, T, H, W, device=device)

    # 1. Test without cache
    with torch.no_grad():
        out_a = model_a(x)
        out_b = model_b(x)

    try:
        torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (No Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ!")
        print(e)
        return

    # 2. Test with cached temporal frames (as used for chunked decoding)
    cache = torch.randn(B, C, 2, H, W, device=device)
    with torch.no_grad():
        out_a_cache = model_a(x, cache_x=cache)
        out_b_cache = model_b(x, cache_x=cache)

    try:
        torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (With Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ with cache!")


def benchmark_performance():
    print("\n=== Running Performance Benchmark ===")
    if not torch.cuda.is_available():
        print("Skipping benchmark (CUDA not available)")
        return

    B, C, T, H, W = 4, 64, 32, 128, 128
    out_C = 64
    kernel = 3
    # Padding set to (1,1,1), so T gets padded by 2, H/W by 1
    model_a, model_b = setup_models(C, out_C, kernel, padding=(1, 1, 1))

    x = torch.randn(B, C, T, H, W, device="cuda")

    def run_a():
        return model_a(x)

    def run_b():
        return model_b(x)

    ms_a = triton.testing.do_bench(run_a, rep=100)
    ms_b = triton.testing.do_bench(run_b, rep=100)

    print(f"Implementation A (F.pad):   {ms_a:.3f} ms")
    print(f"Implementation B (Impl.H/W): {ms_b:.3f} ms")

    diff = (ms_a - ms_b) / ms_a * 100
    print(f"Implementation B is {diff:.2f}% faster")


if __name__ == "__main__":
    test_correctness()
    benchmark_performance()

result:

=== Running Correctness Test ===
[Pass] Outputs are numerically identical (No Cache).
[Pass] Outputs are numerically identical (With Cache).

=== Running Performance Benchmark ===
Implementation A (F.pad):   44.787 ms
Implementation B (Impl.H/W): 42.507 ms
Implementation B is 5.09% faster

c8ef avatar Dec 06 '25 07:12 c8ef

@sayakpaul @yiyixuxu @DN6 Please take a look, thanks!

c8ef avatar Dec 06 '25 07:12 c8ef

@c8ef could we also check if the performance further improves with torch.compile?

sayakpaul avatar Dec 08 '25 04:12 sayakpaul

@c8ef could we also check if the performance further improves with torch.compile?

After compilation, the two CausalConv3d variants perform similarly.

However, I believe this optimization is still useful in scenarios where we cannot compile, for example when compiling the entire VAE adds too much startup latency, or when compiling certain modules may negatively impact video quality.
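
For reference, a rough sketch of how the torch.compile comparison can be run with the same harness. It reuses setup_models and the benchmark shapes from the testing script above; the "max-autotune" mode and warm-up count are just illustrative choices.

import torch
import triton.testing

# Reuses setup_models and the benchmark shapes from the testing script above.
model_a, model_b = setup_models(64, 64, 3, padding=(1, 1, 1))
x = torch.randn(4, 64, 32, 128, 128, device="cuda")

compiled_a = torch.compile(model_a, mode="max-autotune")
compiled_b = torch.compile(model_b, mode="max-autotune")

# Warm up so compilation time is excluded from the measurement.
for _ in range(3):
    compiled_a(x)
    compiled_b(x)

ms_a = triton.testing.do_bench(lambda: compiled_a(x), rep=100)
ms_b = triton.testing.do_bench(lambda: compiled_b(x), rep=100)
print(f"compiled A (F.pad): {ms_a:.3f} ms | compiled B (implicit H/W): {ms_b:.3f} ms")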

c8ef avatar Dec 08 '25 06:12 c8ef

Any reason for closing?

sayakpaul avatar Dec 10 '25 09:12 sayakpaul

Any reason for closing?

Sorry, spotty connection and my phone glitched when I touched the screen...

c8ef avatar Dec 10 '25 09:12 c8ef

I'm wondering how we can move this forward. From your perspective as reviewers, is this patch worth merging?

c8ef avatar Dec 11 '25 16:12 c8ef

Gentle ping.

c8ef avatar Dec 16 '25 05:12 c8ef