
[ROCM][Tracker] Wan2.1 Autoencoder3d performance - MI300x


This issue tracks performance burndown of the 3d autoencoder for the Wan2.1 video generation pipeline.

This model differs from the SDXL VAE primarily in that it operates on slices of a video latent input, looping over the "frame" dimension and processing each "latent frame" separately. We also see significantly different dispatches formed (many more Conv3d ops, with different shapes), and the target precision is bf16. There is more work to be done on optimizing this export: currently, the export process unrolls the loop over video frame slices into a static number of repetitions matching the frame count. We should probably instead emit one entrypoint that processes a single frame, and another that performs initialization, an scf.for loop over a dynamic number of input frames, and postprocessing, as sketched below. It is difficult to emit this accurately with the turbine dynamo export stack.
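A rough sketch of that alternative structure is below; `DummyFrameEncoder`, the cache tensor, and the example shapes are all placeholders (the real module lives in the Wan2.1 VAE code). The point is only where the frame loop sits relative to the exported entrypoint:

```python
import torch

class DummyFrameEncoder(torch.nn.Module):
    """Placeholder for the per-frame encode path of the real Wan2.1 VAE.

    The real module would carry temporal Conv3d state between frames;
    here a single tensor `cache` stands in for that state.
    """
    def forward(self, frame: torch.Tensor, cache: torch.Tensor):
        latent = frame.mean(dim=(-1, -2), keepdim=True) + cache  # stand-in math
        return latent, cache + 1.0

def encode_video(frame_encoder: torch.nn.Module, video: torch.Tensor):
    # Host-side loop over the frame dimension. In the exported program this
    # would become an scf.for with a dynamic trip count instead of being
    # unrolled into a fixed number of repetitions.
    cache = torch.zeros(1)
    latents = []
    for t in range(video.shape[2]):  # [N, C, T, H, W] layout assumed
        latent, cache = frame_encoder(video[:, :, t], cache)
        latents.append(latent)
    return torch.stack(latents, dim=2)

# Small toy input; the real target config is 81 frames at 512x512 in bf16.
video = torch.randn(1, 3, 9, 64, 64, dtype=torch.bfloat16)
print(encode_video(DummyFrameEncoder(), video).shape)
```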

That being said, I have run benchmarks on a target configuration (512x512 output, 81 frames encode, 21 frames decode) and have preliminary results for VAE encode which follow:

Artifacts required for reproducing results:

Weights: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_bf16.irpa
VMFB: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_512x512_gfx942.vmfb
Sample inputs: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/vae_encode_input.npy

Optional MLIR: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_512x512.mlir

You can also use the Azure CLI to download them as a batch:

mkdir wan2_1_14b_480p
az storage blob download-batch --account-name sharkpublic --source sharkpublic --destination ./wan2_1_14b_480p --pattern "halo-models/diffusion/wan2_1_14b_480p/*"
cd wan2_1_14b_480p
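After downloading, a quick numpy check of the sample input (a minimal sketch; if the array was saved as bfloat16 you may need ml_dtypes installed for numpy to load it):

```python
import numpy as np

# Inspect the sample encode input shipped with the artifacts above.
x = np.load("vae_encode_input.npy")
print(x.shape, x.dtype)  # exact shape/dtype depend on how the input was exported
```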

Compile command used:

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external --iree-dispatch-creation-enable-fuse-horizontal-contractions=0  --iree-flow-inline-constants-max-byte-length=16 --iree-global-opt-propagate-transposes=1 --iree-opt-const-eval=0 --iree-opt-outer-dim-concat=1 --iree-opt-aggressively-propagate-transposes=1 --iree-dispatch-creation-enable-aggressive-fusion --iree-hal-force-indirect-command-buffers --iree-llvmgpu-enable-prefetch=1 --iree-opt-data-tiling=0 --iree-hal-memoization=1 --iree-opt-strip-assertions --iree-codegen-llvmgpu-early-tile-and-fuse-matmul=1 --iree-stream-resource-memory-model=discrete --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental),iree-preprocessing-convert-conv-filter-to-channels-last{filter-layout=fhwc})' wan2_1_vae_512x512.mlir -o wan2_1_vae_512x512_gfx942.vmfb

Encode (81 frames, 512x512, bf16):

Benchmark command and output (note: latencies are affected by runtime tracing):
HIP_VISIBLE_DEVICES=1 IREE_PY_RUNTIME=tracy iree-benchmark-module --module=wan2_1_vae_512x512_gfx942.vmfb --input=@vae_encode_input.npy --function=encode --device=hip://0 --parameters=model=wan2_1_vae_bf16.irpa --benchmark_repetitions=3
-- Using Tracy runtime (IREE_PY_RUNTIME=tracy)
2025-05-14T16:57:17+00:00
Running /home/eagarvey/shark-ai/.venv/lib/python3.12/site-packages/iree/_runtime_libs_tracy/iree-benchmark-module
Run on (128 X 3762.99 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x128)
  L1 Instruction 32 KiB (x128)
  L2 Unified 1024 KiB (x128)
  L3 Unified 32768 KiB (x16)
Load Average: 3.68, 3.89, 5.81
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
--------------------------------------------------------------------------------------------------
Benchmark                                        Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------------
BM_encode/process_time/real_time              2595 ms         4572 ms            1 items_per_second=0.385282/s
BM_encode/process_time/real_time              2597 ms         4651 ms            1 items_per_second=0.385042/s
BM_encode/process_time/real_time              2599 ms         4709 ms            1 items_per_second=0.384763/s
BM_encode/process_time/real_time_mean         2597 ms         4644 ms            3 items_per_second=0.385029/s
BM_encode/process_time/real_time_median       2597 ms         4651 ms            3 items_per_second=0.385042/s
BM_encode/process_time/real_time_stddev       1.76 ms         68.4 ms            3 items_per_second=260.177u/s
BM_encode/process_time/real_time_cv           0.07 %          1.47 %             3 items_per_second=0.07%
Top dispatches (screenshot of Tracy results; image not reproduced here).

Tracy profile: DDL link

Decode (1 frame, 512x512, bf16):

Attention appears to be responsible for the performance issues in Wan2.1 VAE decode.

Top dispatches (screenshot of Tracy results; image not reproduced here).

Additional Notes

Compiler version:

IREE (https://iree.dev):
  IREE compiler version 3.5.0rc20250514 @ d63e15e15509784de68f1e39f86f78c980031dda
  LLVM version 21.0.0git
  Optimized build

Torch version (affects exported MLIR):

Version: 2.5.1+rocm6.2

monorimet, May 14, 2025

Exports of the original post's MLIR can be reproduced by following the instructions on the shark-ai feature branch @ wan_exports: https://github.com/nod-ai/shark-ai/blob/wan_exports/sharktank/sharktank/torch_exports/wan/README.md

This is the nn.Module we are exporting through iree-turbine aot: orig_vae.py#L506

This is the export entrypoint for VAE: export.py#L268-L271. Line 270 instantiates the nn.Module and sample inputs, and line 271 feeds them into the generalized export function.
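For context, the general shape of an export through iree-turbine AOT looks roughly like the sketch below; `WanVAEWrapper`, its body, and the example input shape are placeholders, not the actual code in export.py:

```python
import torch
from iree.turbine import aot  # assumes the iree-turbine package is installed

class WanVAEWrapper(torch.nn.Module):
    """Placeholder standing in for the wrapped Wan2.1 VAE encode path."""
    def forward(self, x):
        return x * 2.0  # stand-in for vae.encode(x)

model = WanVAEWrapper()
example_input = torch.randn(1, 3, 81, 512, 512, dtype=torch.bfloat16)

# aot.export traces the module and produces MLIR that iree-compile can
# consume with the flags listed in the original post.
exported = aot.export(model, example_input)
exported.save_mlir("wan2_1_vae_512x512.mlir")
```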

Currently, the attention dispatch in VAE decode is performing below expectations -- opinions are requested as to whether this should be improved in the compiler or if we should try to export better attention shapes AOT. @Groverkss

monorimet, May 27, 2025

The next step here, if I understand correctly, is to improve code generation for the attention shape with head dim 384 used in this VAE decode. I can help with this if we are short on hands, but I would need some direction on which changes to look into. @Groverkss is there anything entry-level I can take on here, or should I leave it in your queue?
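To make the shape concrete, the decode attention is roughly a scaled-dot-product problem like the one below; only the head dimension of 384 comes from this thread, and the batch/head/sequence sizes are illustrative guesses:

```python
import torch
import torch.nn.functional as F

# Illustrative q/k/v layout: [batch, heads, seq_len, head_dim].
# head_dim=384 is the shape discussed above; the other sizes are placeholders.
q = torch.randn(1, 1, 4096, 384, dtype=torch.bfloat16)
k = torch.randn(1, 1, 4096, 384, dtype=torch.bfloat16)
v = torch.randn(1, 1, 4096, 384, dtype=torch.bfloat16)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 1, 4096, 384])
```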

monorimet, May 29, 2025

I looked at the attention IR; it's going down the memory-bound attention pipeline. The reason is that our attention/MMA heuristics don't do a good job of checking whether the copy from shared memory is aligned to the workgroup size. I can send a fix tomorrow.
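Purely as an illustration of the alignment point (the copy widths below are made up, not what the compiler actually chooses): a head dimension of 384 divides evenly by some power-of-two tile widths but not others, so the heuristic's answer depends on which tile it checks against:

```python
# Illustrative only: how a 384-wide copy lines up with hypothetical tile widths.
head_dim = 384
for copy_width in (64, 128, 256):
    print(f"copy_width={copy_width}: aligned={head_dim % copy_width == 0}")
# 384 is a multiple of 64 and 128 but not of 256.
```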

Groverkss, Jun 10, 2025

@Groverkss I am reimplementing this model in sharktank to get a better export, but I suspect we will still need to support these cases. Has there been any change since we last left off? I can file a separate issue for the attention dispatch if that would help.

monorimet, Sep 22, 2025