torch.compile + channels_last support for Wan 2.2 (T2V / I2V) fails with RuntimeError + Dynamo Unsupported behavior
Describe the bug
Hi, I am trying to optimize Wan 2.2 T2V / I2V inference speed on a single RTX 4090, using:
1. Wan2.2 (Diffusers)
2. LightX2V LoRA
3. flash attention
4. group offload (Diffusers 0.30+)
5. torch.compile(mode="max-autotune", fullgraph=True) / torch.channels_last (as recommended in the docs)
My goal is to achieve maximum throughput on a single 4090 GPU. However, when following the official docs for efficiency: https://huggingface.co/docs/diffusers/api/pipelines/wan#t2v-inference-speed
I hit two different failures:
1. RuntimeError when calling .to(memory_format=torch.channels_last)
According to the docs:
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(
pipeline.transformer, mode="max-autotune", fullgraph=True
)
I got the following error:
Traceback (most recent call last):
File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 42, in <module>
pipe.transformer.to(memory_format=torch.channels_last)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
return super().to(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1343, in to
return self._apply(convert)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 903, in _apply
module._apply(fn)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 930, in _apply
param_applied = fn(param)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1323, in convert
return t.to(
RuntimeError: required rank 4 tensor to use channels_last format
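For context, a minimal standalone sketch of the constraint (plain tensors here, not the Wan weights): torch.channels_last only applies to rank-4 (NCHW) tensors, while rank-5 (NCDHW) video-style tensors would need torch.channels_last_3d.
import torch

x4d = torch.randn(1, 3, 32, 32)     # NCHW, rank 4
x5d = torch.randn(1, 3, 8, 32, 32)  # NCDHW, rank 5 (video-style)

x4d.to(memory_format=torch.channels_last)      # fine: rank-4 tensor
x5d.to(memory_format=torch.channels_last_3d)   # fine: rank-5 variant
# x5d.to(memory_format=torch.channels_last)    # -> RuntimeError: required rank 4 tensor to use channels_last format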
2. When skipping channels_last and compiling directly, torch.compile fails at runtime
I attempted:
# Skipped channels_last
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
pipe.transformer_2 = torch.compile(
pipe.transformer_2, mode="max-autotune", fullgraph=True
)
I got the following error:
0%| | 0/6 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 89, in <module>
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
noise_pred = current_model(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
return self.func.call_function(tx, merged_args, merged_kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 914, in call_function
return func_var.call_function(tx, [obj_var] + args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3116, in inline_call_
result = InliningInstructionTranslator.check_inlineable(func)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3093, in check_inlineable
unimplemented(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: ModuleGroup.onload_ | _fn /data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py, skipped according trace_rules.lookup SKIP_DIRS'
from user code:
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 189, in new_forward
output = function_reference.forward(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/group_offloading.py", line 304, in pre_forward
self.group.onload_()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Full code:
import torch
from diffusers import WanImageToVideoPipeline, DiffusionPipeline, LCMScheduler, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
from diffusers.utils import export_to_video
import safetensors.torch
from diffusers.hooks import apply_group_offloading
import time
# Load image
# image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
# response = requests.get(image_url)
# input_image = Image.open(BytesIO(response.content)).convert("RGB")
input_image = Image.open("/data/code/haobang.geng/code/online_storymv_generate/temp/temp_input/1.jpg").convert("RGB")
warmup_steps = 3
# load pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
"/data/code/haobang.geng/models/Wan2.2-I2V-A14B-Diffusers",
torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
# load and fuse lora
high_lora_path = "/data/code/haobang.geng/models/WanVideo_comfy/LoRAs/Wan22_Lightx2v/Wan_2_2_I2V_A14B_HIGH_lightx2v_4step_lora_v1030_rank_64_bf16.safetensors"
low_lora_path = "/data/code/haobang.geng/ComfyUI/models/loras/Wan_2_1_lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"
pipe.load_lora_weights(high_lora_path, adapter_name='lightx2v_t1')
pipe.set_adapters(["lightx2v_t1"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["lightx2v_t1"], lora_scale=1, components=["transformer"])
if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
org_state_dict = safetensors.torch.load_file(low_lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="lightx2v_t2")
pipe.transformer_2.set_adapters(["lightx2v_t2"], weights=[1.0])
pipe.fuse_lora(adapter_names=["lightx2v_t2"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()
# torch.compile
# pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
# pipe.transformer_2.to(memory_format=torch.channels_last)
pipe.transformer_2 = torch.compile(
pipe.transformer_2, mode="max-autotune", fullgraph=True
)
# group offload
apply_group_offloading(
pipe.transformer,
offload_type="leaf_level",
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
use_stream=True,
)
apply_group_offloading(
pipe.transformer_2,
offload_type="leaf_level",
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
use_stream=True,
)
apply_group_offloading(
pipe.text_encoder,
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
offload_type="leaf_level",
use_stream=True,
)
apply_group_offloading(
pipe.vae,
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
offload_type="leaf_level",
use_stream=True,
)
# set efficient attention
pipe.transformer.set_attention_backend("flash")
# for i in range(warmup_steps):
# frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
start_time = time.time()
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
export_to_video(frames, "/data/code/haobang.geng/code/online_storymv_generate/temp/temp_output/output.mp4",fps=15)
Request
1. Can the Wan2.2 transformer support channels_last? (channels_last currently requires rank-4 tensors.)
2. Can the team patch torch.compile compatibility for the Wan2.2 T2V/I2V transformers?
3. Are there recommended compiler flags (e.g., dynamic=True, fullgraph=False) that work reliably for Wan2.2?
Reproduction
python fullcode.py
Logs
System Info
torch 2.6.0+cu124
torchaudio 2.6.0+cu124
torchsde 0.2.6
torchvision 0.21.0+cu124
diffusers 0.35.2
Who can help?
No response
Hi @Passenger12138, could you try enabling group offloading before compiling?
Hi @DN6, I tried the solution you suggested, like so:
# group offload
apply_group_offloading(
pipe.transformer,
offload_type="leaf_level",
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
use_stream=True,
)
apply_group_offloading(
pipe.transformer_2,
offload_type="leaf_level",
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
use_stream=True,
)
apply_group_offloading(
pipe.text_encoder,
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
offload_type="leaf_level",
use_stream=True,
)
apply_group_offloading(
pipe.vae,
offload_device=torch.device("cpu"),
onload_device=torch.device("cuda"),
offload_type="leaf_level",
use_stream=True,
)
# torch.compile
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
pipe.transformer_2.to(memory_format=torch.channels_last)
pipe.transformer_2 = torch.compile(
pipe.transformer_2, mode="max-autotune", fullgraph=True
)
# set efficient attention
pipe.transformer.set_attention_backend("flash")
# for i in range(warmup_steps):
# frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
start_time = time.time()
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
export_to_video(frames, "/data/code/haobang.geng/code/online_storymv_generate/temp/temp_output/output.mp4",fps=15)
but I still encountered the same problem:
Traceback (most recent call last):
File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 73, in <module>
pipe.transformer.to(memory_format=torch.channels_last)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
return super().to(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1343, in to
return self._apply(convert)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 903, in _apply
module._apply(fn)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 930, in _apply
param_applied = fn(param)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1323, in convert
return t.to(
RuntimeError: required rank 4 tensor to use channels_last format
Hey @Passenger12138 @DN6,
I spent some time looking into this and verifying the reproduction steps. It looks like there are two fundamental conflicts here that make this specific combination of optimizations really tricky on "consumer hardware".
- The channels_last issue: The RuntimeError happens because torch.channels_last strictly expects rank-4 tensors (NCHW), but Wan is a video model that uses rank-5 tensors. I actually tried switching to memory_format=torch.channels_last_3d to see if that was a quick fix. It didn't crash immediately during the .to() call, but torch.compile still failed later when trying to trace it, resulting in a huge Unsupported traceback.
- torch.compile vs. offloading: There seems to be a hard conflict between Dynamo (the compiler) and any form of offloading. Whether I used group_offloading or enable_model_cpu_offload, Dynamo errors out because it can't trace the Python hooks that handle the memory swapping (onload_ or the accelerate hooks).
I also tried running the optimization example from the official docs for Wan 2.1 T2V 1.3B. On a 16GB GPU, pipeline.to("cuda") causes an immediate OOM (the T5 encoder is huge). If I switch to pipeline.enable_model_cpu_offload() to fix the OOM, I hit the exact same compilation errors mentioned above.
It seems like for now, we have to pick one: either offloading (to fit the model in VRAM) or compilation (for speed), but getting both working together isn't currently supported. It might be worth updating the docs example to reflect this so users don't run into these crashes out of the box.
@sayakpaul any thoughts?
torch.compile vs. Offloading: There seems to be a hard conflict between Dynamo (the compiler) and any form of offloading. Whether I used group_offloading or enable_model_cpu_offload, Dynamo errors out because it can't trace the Python hooks that handle the memory swapping (onload_ or accelerate hooks).
You can compile with offloading; you just have to set fullgraph=False. We test for this stuff, see here:
https://github.com/huggingface/diffusers/blob/b010a8ce0c5b5288045045f9f79c496899e80b5a/tests/models/test_modeling_common.py#L2098
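For example, here is a minimal sketch of that ordering, reusing the pipeline objects from the script above: group offloading is applied first, then compilation with fullgraph=False so Dynamo is allowed to graph-break at the offloading hooks.
# apply group offloading first ...
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
# ... then compile, allowing graph breaks around the offloading hooks
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=False)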
For 16GB GPUs, compilation might actually have fewer benefits (see this guide):
Inductor is optimized towards A100 and H100 performance. Even V100 are out of the happy path, and consumer cards are especially likely to have problems. A lot of this is due to targeting Triton, which has a relatively limited set of cards it is highly optimized for. We would like to do better, especially on consumer cards, but this is a big project, external contributions welcome!
So, I would suggest sticking to offloading.
Thanks for your insights @sayakpaul.
However, the documentation code should use channels_last_3d instead of channels_last to support the video architecture, shouldn't it?
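For reference, the rank-5 variant of the docs snippet would look like the sketch below. As noted earlier in the thread, this only gets past the .to() call; compilation can still fail further down.
pipe.transformer.to(memory_format=torch.channels_last_3d)
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)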
Getting the same issue with Wan-2.1-14B. Are the Wan-2.1 docs outdated?
import torch
from diffusers.utils import export_to_video, load_video
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
prompt = "First-person POV, realistic footage, handheld camera, natural lighting, raw unedited feel, high quality, 4k, authentic atmosphere."
negative_prompt = "Gameplay, unrealistic, Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
HEIGHT = 512
WIDTH = 512
CLIP_FRAMES = 121 # Number of frames to extract per clip
NUM_INFERENCE_STEPS = 40
SEED = 1
FPS = 60
video = load_video("outputs-1/input_videos/1.mp4")
print(video[0].size)  # load_video returns PIL images, which expose .size rather than .shape
output = pipe(
video=video,
prompt=prompt,
negative_prompt=negative_prompt,
height=HEIGHT,
width=WIDTH,
guidance_scale=5.0,
strength=0.7,
).frames[0]
export_to_video(output, "output.mp4", fps=FPS)
File "/home/sky/anmol/owl-style-transfer/tests/wanv2v.py", line 13, in <module>
pipe.transformer.to(memory_format=torch.channels_last)
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/diffusers/models/modeling_utils.py", line 1435, in to
return super().to(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1371, in to
return self._apply(convert)
^^^^^^^^^^^^^^^^^^^^
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/torch/nn/modules/module.py", line 930, in _apply
module._apply(fn)
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/torch/nn/modules/module.py", line 957, in _apply
param_applied = fn(param)
^^^^^^^^^
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1351, in convert
return t.to(
^^^^^
RuntimeError: required rank 4 tensor to use channels_last format
When I try to use channels_last_3d, this is the error I receive:
File "/home/sky/anmol/owl-style-transfer/tests/wanv2v.py", line 43, in <module>
output = pipe(
^^^^^
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_video2video.py", line 679, in __call__
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
~~~~~~~~~~~^~~~~~~~~~~~~~
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/sky/miniconda3/envs/style/lib/python3.12/site-packages/diffusers/models/transformers/transformer_wan.py", line 710, in forward
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
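One possible workaround, hinted at by the error message itself and sketched here as an untested assumption: either call torch.compiler.cudagraph_mark_step_begin() before each pipeline call, or skip CUDA graph capture entirely by compiling with the max-autotune-no-cudagraphs mode so the compiled transformer's outputs are no longer backed by CUDA graph memory pools.
# assumption: avoiding CUDA graphs sidesteps the overwritten-output error
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
)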