Add Wan2.2-S2V: Audio-Driven Cinematic Video Generation
This PR fixes #12257.
Comparison with the original repo
When I applied `with torch.amp.autocast('cuda', dtype=torch.bfloat16):` to the transformer only and changed the initial noise's dtype from `torch.bfloat16` to `torch.float32` in the original repo, the videos look almost identical. As far as I can see, the original repo's video has an extra blink.
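For clarity, here is a minimal, self-contained sketch of that dtype arrangement (a toy module stands in for the transformer; this is not the Wan2.2 code itself, and it assumes a CUDA device):

import torch

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

transformer = ToyTransformer().to("cuda")
noise = torch.randn(1, 64, dtype=torch.float32, device="cuda")  # initial noise kept in float32
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):  # autocast wraps the transformer only
    out = transformer(noise)  # matmuls run in bfloat16 under autocast
print(out.dtype)  # torch.bfloat16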
Try WanSpeechToVideoPipeline!
!git clone https://github.com/tolgacangoz/diffusers.git
%cd diffusers
#!git switch "integrations/wan2.2-s2v" # This is constantly changing...
!git switch "wan2.2-s2v"
!pip install pip uv -qU
!uv pip install -e ".[dev]" -q
!uv pip install imageio-ffmpeg ftfy decord ninja packaging kernels -q
# For Flash attention 2:
#!uv pip install flash-attn --no-build-isolation
# For Flash attention 3 in diffusers:
#import os
#os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"
import numpy as np
import torch, os
from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
from diffusers.utils import export_to_video, load_image, load_audio, load_video
from transformers import Wav2Vec2ForCTC
model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" # will be official
model_id = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanSpeechToVideoPipeline.from_pretrained(
model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16,
)#.to("cuda")
pipe.enable_model_cpu_offload()
#pipe.transformer.set_attention_backend("flash") # FA 2
#pipe.transformer.set_attention_backend("_flash_3_hub") # FA 3
first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")
import math
def get_size_less_than_area(height,
width,
target_area=1024 * 704,
divisor=64):
if height * width <= target_area:
# If the original image area is already less than or equal to the target,
        # no resizing is needed; just padding. Still need to ensure that the padded area doesn't exceed the target.
max_upper_area = target_area
min_scale = 0.1
max_scale = 1.0
else:
# Resize to fit within the target area and then pad to multiples of `divisor`
max_upper_area = target_area # Maximum allowed total pixel count after padding
d = divisor - 1
b = d * (height + width)
a = height * width
c = d**2 - max_upper_area
# Calculate scale boundaries using quadratic equation
min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (
2 * a) # Scale when maximum padding is applied
max_scale = math.sqrt(max_upper_area /
(height * width)) # Scale without any padding
    # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
    # Scan scales linearly from max_scale down toward min_scale to find it
find_it = False
for i in range(100):
scale = max_scale - (max_scale - min_scale) * i / 100
new_height, new_width = int(height * scale), int(width * scale)
# Pad to make dimensions divisible by 64
pad_height = (64 - new_height % 64) % 64
pad_width = (64 - new_width % 64) % 64
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
padded_height, padded_width = new_height + pad_height, new_width + pad_width
if padded_height * padded_width <= max_upper_area:
find_it = True
break
if find_it:
return padded_height, padded_width
else:
# Fallback: calculate target dimensions based on aspect ratio and divisor alignment
aspect_ratio = width / height
target_width = int(
(target_area * aspect_ratio)**0.5 // divisor * divisor)
target_height = int(
(target_area / aspect_ratio)**0.5 // divisor * divisor)
# Ensure the result is not larger than the original resolution
if target_width >= width or target_height >= height:
target_width = int(width // divisor * divisor)
target_height = int(height // divisor * divisor)
return target_height, target_width
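# Illustrative sanity check of the helper above using an arbitrary 720x1280 input
# (not tied to the demo image): the returned dimensions are divisible by `divisor`
# and their product stays within the target area.
_h, _w = get_size_less_than_area(720, 1280, target_area=480 * 832)
assert _h % 64 == 0 and _w % 64 == 0
assert _h * _w <= 480 * 832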
height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)
prompt = "Einstein singing a song."
output = pipe(
image=first_frame, audio=audio, sampling_rate=sampling_rate,
prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)
import logging, shutil, subprocess
def merge_video_audio(video_path: str, audio_path: str):
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# set logging
logging.basicConfig(level=logging.INFO)
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
logging.info("Start merging video and audio...")
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
logging.error(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
logging.info(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
logging.error(f"merge_video_audio failed with error: {e}")
import requests, tempfile
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
response = requests.get("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav", stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
with tempfile.NamedTemporaryFile(delete=False) as talk:
for chunk in response.iter_content(chunk_size=8192):
talk.write(chunk)
talk_file = talk.name
merge_video_audio("video.mp4", talk_file)
@yiyixuxu @sayakpaul @asomoza @dg845 @stevhliu @WanX-Video-1 @Steven-SWZhang @kelseyee @SHYuanBest @J4BEZ @okaris @xziayro-ai @teith @luke14free @lopho @arnold408
hey @tolgacangoz thanks a lot for your work, looks amazing. I tried running the script you attached to the message above and I am getting this error both out of the box and also when removing the CPU offloading and doing a manual .to("cuda"). I had encountered it before, and it was an issue with the VAE being on a different device and/or dtypes being incompatible; not sure exactly what the issue is here.
error stack trace
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 160.63it/s]
Loading checkpoint shards: 100%|██████████| 7/7 [00:00<00:00, 58.33it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 10.55it/s]
Attention backends are an experimental feature and the API may be subject to change.
Traceback (most recent call last):
File "/home/luca/wan-2.2-s2v/test.py", line 93, in <module>
output = pipe(
^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_s2v.py", line 842, in __call__
latents_outputs = self.prepare_latents(
^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_s2v.py", line 566, in prepare_latents
pose_condition = self.load_pose_condition(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_s2v.py", line 606, in load_pose_condition
pose_condition = retrieve_latents(self.vae.encode(all_poses), sample_mode="argmax")[:, :, 1:]
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1191, in encode
h = self._encode(x)
^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1158, in _encode
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 593, in forward
x = self.conv_in(x, feat_cache[idx])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 176, in forward
return super().forward(x)
^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 717, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/luca/wan-2.2-s2v/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 712, in _conv_forward
return F.conv3d(
^^^^^^^^^
NotImplementedError: Could not run 'aten::slow_conv3d_forward' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::slow_conv3d_forward' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradMAIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastMTIA, AutocastMAIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
CPU: registered at /pytorch/build/aten/src/ATen/RegisterCPU_2.cpp:8588 [kernel]
Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:375 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradHIP: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradMPS: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradIPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradXPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradHPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradVE: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradLazy: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradMTIA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradMAIA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradMeta: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
AutogradNestedTensor: registered at /pytorch/torch/csrc/autograd/generated/VariableType_4.cpp:19365 [autograd kernel]
Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_4.cpp:13560 [kernel]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastMTIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
AutocastMAIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:504 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:542 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]
Thanks for your feedback @luke14free! I was standardizing some device handling and dimensions. I fixed some of them, but I will be able to fix them all tomorrow.
@tolgacangoz is this ready for review now?
We can say yes, except for ~documentation~, ~test file~, ~and the last thing we were discussing~.
Hi @luke14free, it should be OK now. Could you please try again?
Good morning Tolga! Giving it a spin as we speak, also testing GGUFs and first block cache to make sure it's solid all round, will keep you posted and thanks again for your great work!
Regular generation works well: https://1nf.sh/tasks/0wqecremywnt6kkprepqxb7rb5, but GGUFs and FBC are not working (I think that's expected as they are not implemented? https://1nf.sh/tasks/7v3yqpqrngn546r2xf9w213v0c)
FBC seems to work when I put pipe.enable_model_cpu_offload() after para_attn.first_block_cache.diffusers_adapters.apply_cache_on_pipe(pipe, residual_diff_threshold=0.08). It accelerates things a bit, but I observed restricted movement in the video with residual_diff_threshold=0.08.
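For reference, a small sketch of the call order described above (it assumes para_attn is installed and that `pipe` has already been built as in the example script):

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.08)  # apply FBC first
pipe.enable_model_cpu_offload()  # then enable CPU offloading, in this order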
import os
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "YES"
...
pipe.transformer.set_attention_backend("_flash_3_hub")
hi @tolgacangoz
do you want to join a slack channel with us so it is easier to iterate over this?
Since the results are visually different, I think we should first make a unit test for WanS2VTransformer and make sure the outputs are 1:1 with the original; once that's matched, we can look into the pipeline itself
so something like this
from Wan2.2 import WanModel_S2V
from diffusers import WanS2VTransformer
their_model = WanModel_S2V(...)
our_model = WanS2VTransformer(...)
# same inputs
latents = torch.randn(...)
timestep = torch.tensor([...])
with torch.no_grad():
their_output = their_model(...)
our_output = our_model(...)
torch.allclose(...)
OK, let's join. Tbh, I didn't match the outputs 100%. I thought that since the output seemed OK, merging the PR was the higher priority. Alright, I will match them 100% tomorrow. Normally, that's how I have been doing it; you can look at my Magi-1 PR. I assumed this PR should be reviewed and merged quickly.
Hi @kelseyee. The official repo at HF will be required. Could you open a placeholder repo, i.e. Wan-AI/Wan2.2-S2V-14B-Diffusers, so that I can open a PR there?
hi @tolgacangoz I can help with the repo once it's ready. We are still refactoring the transformer, so the checkpoints are not finalized yet.
@tolgacangoz I truly appreciate your great work, and apologize for the late response
When I tried running the script attached to the message, I encountered the following error at pipe.enable_model_cpu_offload():
error stack trace
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 195.72it/s]
Loading checkpoint shards: 100%|██████████| 7/7 [00:00<00:00, 88.93it/s]
Some weights of the model checkpoint at models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer were not used when initializing WanS2VTransformer3DModel:
['condition_embedder.causal_audio_encoder.encoder.conv2.conv.weight', 'condition_embedder.causal_audio_encoder.encoder.conv3.conv.bias', 'condition_embedder.causal_audio_encoder.weights', 'condition_embedder.causal_audio_encoder.encoder.conv3.conv.weight', 'condition_embedder.causal_audio_encoder.encoder.conv2.conv.bias']
Some weights of WanS2VTransformer3DModel were not initialized from the model checkpoint at models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer and are newly initialized: ['condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.bias', 'condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.weight', 'condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.weight', 'condition_embedder.causal_audio_encoder.weighted_avg.weights', 'condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 14.41it/s]
...
python3.11/site-packages/torch/nn/modules/module.py:1336, in Module.to.<locals>.convert(t)
1334 except NotImplementedError as e:
1335 if str(e) == "Cannot copy out of meta tensor; no data!":
-> 1336 raise NotImplementedError(
1337 f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
1338 f"when moving module from meta to a different device."
1339 ) from None
1340 else:
1341 raise
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
Slightly modifying the keys in the transformer checkpoint and its index.json seemed to avoid the error:
key name remap
import os
import json
from safetensors import safe_open
from safetensors.torch import save_file
from typing import Dict
def remap_safetensors_keys(transformer_dir_path: str, key_remap_dict: Dict[str, str]):
"""
Remap keys in safetensors files and save to a new file
Args:
transformer_dir_path (str): Path to the transformer directory
key_remap_dict (Dict[str, str]): Dictionary for key mapping {old_key: new_key}
"""
# Load index file
index_json_path = os.path.join(transformer_dir_path, 'diffusion_pytorch_model.safetensors.index.json')
with open(index_json_path, 'r') as f:
index_data = json.load(f)
# Get weight_map (this is where the actual key->file mapping is)
weight_map = index_data.get("weight_map", {})
# Build a map from files to keys that need remapping
file_path_key_remap_dict = {}
new_weight_map = {}
remapped_count = 0
# Process each key in weight_map
for key, file_name in weight_map.items():
if key in key_remap_dict:
new_key = key_remap_dict[key]
new_weight_map[new_key] = file_name
# Add to file-specific remap dict
if file_name not in file_path_key_remap_dict:
file_path_key_remap_dict[file_name] = {}
file_path_key_remap_dict[file_name][key] = new_key
else:
# Keep original key
new_weight_map[key] = file_name
# Update index file with new weight_map
index_data["weight_map"] = new_weight_map
with open(index_json_path, 'w') as f:
json.dump(index_data, f, indent=2)
# Process each safetensors file
for file_name, remap_dict in file_path_key_remap_dict.items():
file_path = os.path.join(transformer_dir_path, file_name)
print(f"Loading tensors from {file_name}...")
# Read all tensors
tensors_dict = {}
with safe_open(file_path, framework='pt', device='cpu') as f:
keys = list(f.keys())
print(f"Total keys found: {len(keys)}")
for key in keys:
if key in remap_dict:
new_key = remap_dict[key]
print(f"Remapping: {key} -> {new_key}")
tensors_dict[new_key] = f.get_tensor(key)
remapped_count += 1
else:
tensors_dict[key] = f.get_tensor(key)
print(f"Saving remapped tensors to {file_path}...")
save_file(tensors_dict, file_path)
print("Done!")
return remapped_count
# key mapping dictionary
key_remap_dict = {
"condition_embedder.causal_audio_encoder.encoder.conv2.conv.weight": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.weight",
"condition_embedder.causal_audio_encoder.encoder.conv3.conv.bias": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.bias",
"condition_embedder.causal_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights",
"condition_embedder.causal_audio_encoder.encoder.conv3.conv.weight": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.weight",
"condition_embedder.causal_audio_encoder.encoder.conv2.conv.bias": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.bias"
}
# path to transformer directory
transformer_dir_path = 'models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer'
# execute remapping
remapped_count = remap_safetensors_keys(transformer_dir_path, key_remap_dict)
print(f"Successfully remapped {remapped_count} keys!")
"""
Loading tensors from diffusion_pytorch_model-00001-of-00007.safetensors...
Total keys found: 199
Remapping: condition_embedder.causal_audio_encoder.encoder.conv2.conv.bias -> condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.bias
Remapping: condition_embedder.causal_audio_encoder.encoder.conv2.conv.weight -> condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.weight
Remapping: condition_embedder.causal_audio_encoder.encoder.conv3.conv.bias -> condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.bias
Remapping: condition_embedder.causal_audio_encoder.encoder.conv3.conv.weight -> condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.weight
Remapping: condition_embedder.causal_audio_encoder.weights -> condition_embedder.causal_audio_encoder.weighted_avg.weights
Saving remapped tensors to models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer/diffusion_pytorch_model-00001-of-00007.safetensors...
Done!
Successfully remapped 5 keys!
"""
Once again, I sincerely appreciate your wonderful dedication, and I wish you a blessed and peaceful day!
ps. When I ran the code as is, I encountered an out-of-memory (OOM) issue on my machine, so I applied NF4 quantization before executing it, as shown below. I hope this may be of help to others who encounter a similar issue.
code
# !pip install bitsandbytes -qU
# ... same before ...
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, WanS2VTransformer3DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, UMT5EncoderModel
model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" # will be official
# download with `hf download tolgacangoz/Wan2.2-S2V-14B-Diffusers --local-dir models/tolgacangoz/Wan2.2-S2V-14B-Diffusers`
model_id = "models/tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
text_encoder_quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
text_encoder = UMT5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", quantization_config=text_encoder_quant_config, torch_dtype=torch.bfloat16
)
transformer_quant_config = DiffusersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
transformer = WanS2VTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16, quantization_config=transformer_quant_config
)
pipe = WanSpeechToVideoPipeline.from_pretrained(
model_id, vae=vae, audio_encoder=audio_encoder, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")
height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)
prompt = "A Cat is talking."
output = pipe(
image=first_frame, audio=audio, sampling_rate=sampling_rate,
prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)
# ... same after ...
- result
https://github.com/user-attachments/assets/e5c27efe-9573-4631-ba6e-0b396d6ef0e7
This branch is constantly changing. I put a functionally identical branch in the script attached to the first message.
Dear @tolgacangoz, I appreciate your hard work again!
While trying to use pipe.enable_sequential_cpu_offload() instead of pipe.enable_model_cpu_offload(),
I encountered an error like below:
error stack trace
File .../torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
117 @functools.wraps(func)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
File .../diffusers/pipelines/wan/pipeline_wan_s2v.py:882, in WanSpeechToVideoPipeline.__call__(self, image, audio, sampling_rate, prompt, negative_prompt, pose_video_path_or_url, height, width, num_frames_per_chunk, num_inference_steps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, image_embeds, audio_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, init_first_frame, sampling_fps, num_chunks)
879 negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
881 if audio_embeds is None:
--> 882 audio_embeds, num_chunks_audio = self.encode_audio(
883 audio, sampling_rate, num_frames_per_chunk, sampling_fps, device
884 )
885 if num_chunks is None or num_chunks > num_chunks_audio:
886 num_chunks = num_chunks_audio
File .../diffusers/pipelines/wan/pipeline_wan_s2v.py:351, in WanSpeechToVideoPipeline.encode_audio(self, audio, sampling_rate, num_frames, fps, device)
349 input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values
350 # retrieve logits & take argmax
--> 351 res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True)
352 feat = torch.cat(res.hidden_states)
354 feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
File .../torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File .../torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File .../accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
File .../transformers/models/wav2vec2/modeling_wav2vec2.py:1862, in Wav2Vec2ForCTC.forward(self, input_values, attention_mask, output_attentions, output_hidden_states, return_dict, labels)
1859 if labels is not None and labels.max() >= self.config.vocab_size:
1860 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
-> 1862 outputs = self.wav2vec2(
1863 input_values,
1864 attention_mask=attention_mask,
1865 output_attentions=output_attentions,
1866 output_hidden_states=output_hidden_states,
1867 return_dict=return_dict,
1868 )
1870 hidden_states = outputs[0]
1871 hidden_states = self.dropout(hidden_states)
File .../torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File .../torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File .../accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
169 def new_forward(module, *args, **kwargs):
--> 170 args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
171 if module._hf_hook.no_grad:
172 with torch.no_grad():
File .../accelerate/hooks.py:369, in AlignDevicesHook.pre_forward(self, module, *args, **kwargs)
358 self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
360 set_module_tensor_to_device(
361 module,
362 name,
(...) 366 tied_params_map=self.tied_params_map,
367 )
--> 369 return send_to_device(args, self.execution_device), send_to_device(
370 kwargs, self.execution_device, skip_keys=self.skip_keys
371 )
File .../accelerate/utils/operations.py:169, in send_to_device(tensor, device, non_blocking, skip_keys)
167 return tensor.to(device)
168 elif isinstance(tensor, (tuple, list)):
--> 169 return honor_type(
170 tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
171 )
172 elif isinstance(tensor, Mapping):
173 if isinstance(skip_keys, str):
File .../accelerate/utils/operations.py:81, in honor_type(obj, generator)
79 return type(obj)(*list(generator))
80 else:
---> 81 return type(obj)(generator)
File .../accelerate/utils/operations.py:170, in <genexpr>(.0)
167 return tensor.to(device)
168 elif isinstance(tensor, (tuple, list)):
169 return honor_type(
--> 170 tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
171 )
172 elif isinstance(tensor, Mapping):
173 if isinstance(skip_keys, str):
File .../accelerate/utils/operations.py:153, in send_to_device(tensor, device, non_blocking, skip_keys)
151 device = "npu:0"
152 try:
--> 153 return tensor.to(device, non_blocking=non_blocking)
154 except TypeError: # .to() doesn't accept non_blocking as kwarg
155 return tensor.to(device)
NotImplementedError: Cannot copy out of meta tensor; no data!
After some investigation, I found a workaround that resolved the issue on my end, so I wanted to share the changes I made in case they're helpful.
In encode_audio() in pipeline_wan_s2v.py:
def encode_audio(
self,
audio: PipelineAudioInput,
sampling_rate: int,
num_frames: int,
fps: int = 16,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
video_rate = 30
audio_sample_m = 0
input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values
# retrieve logits & take argmax
- res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True)
+ res = self.audio_encoder(input_values.to(device), output_hidden_states=True)
feat = torch.cat(res.hidden_states)
...
and in load_pose_condition():
def load_pose_condition(
self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std
):
+ device = self._execution_device
+ dtype = self.vae.dtype
if pose_video is not None:
padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2]
- pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device)
+ pose_video = pose_video.to(dtype=dtype, device=device)
pose_video = torch.cat(
[
pose_video,
-torch.ones(
- [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device
+ [1, 3, padding_frame_num, height, width], dtype=dtype, device=device
),
],
dim=2,
)
pose_video = torch.chunk(pose_video, num_chunks, dim=2)
else:
pose_video = [
- -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device)
+ -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=dtype, device=device)
]
I hope this helps a little! Thanks for your dedication and hope you stay healthy and have a peaceful day!
Thanks @J4BEZ, fixed it.
@tolgacangoz Thanks! I am delighted to help ☺️
Have a peaceful day!
This will be my second official pipeline contribution and my fourth overall, yay :partying_face:
Just a word of encouragement. This technology is actually quite good, and I hope it'll be prioritized for review soonish. Here's a video I did with it: https://m.youtube.com/watch?v=N7ARyKKwGfc
Hi @tolgacangoz
I appreciate your hard work. I tried to use your new pipeline but didn't succeed in making it work the way I want.
Trying to load a lightx2v LoRA does not succeed:
Error processing message: 'FrozenDict' object has no attribute 'image_dim'
    pipe.load_lora_weights(
  File "/opt/venv/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 4068, in load_lora_weights
    state_dict = self._maybe_expand_t2v_lora_for_i2v(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 3999, in _maybe_expand_t2v_lora_for_i2v
    if transformer.config.image_dim is None:
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'FrozenDict' object has no attribute 'image_dim'
Without LoRA, I tried the pipeline with .to("cuda"), pipe.enable_model_cpu_offload(), and enable_group_offload, and always got the same error:
  pipe = WanSpeechToVideoPipeline.from_pretrained(
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/diffusers/pipelines/pipeline_utils.py", line 1021, in from_pretrained
    loaded_sub_model = load_sub_model(
    ^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/diffusers/pipelines/pipeline_loading_utils.py", line 876, in load_sub_model
    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 1316, in from_pretrained
    dispatch_model(model, **device_map_kwargs)
  File "/opt/venv/lib/python3.11/site-packages/accelerate/big_modeling.py", line 502, in dispatch_model
    model.to(device)
  File "/opt/venv/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
    return super().to(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1371, in to
    return self._apply(convert)
    ^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
    module._apply(fn)
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
    module._apply(fn)
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
    module._apply(fn)
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 957, in _apply
    param_applied = fn(param)
    ^^^^^^^^^
  File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1364, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
It's probably not related, but when I tested with .to("cuda") I had SageAttention activated with pipe.transformer.set_attention_backend("sage").
Hi @zecloud, thanks for reporting this! I will take a look at it tomorrow (+ conflicts below).
Hi @zecloud. AFAIU, there is no Lightning LoRA specifically for the Wan2.2-S2V model. I guess people try to use Wan2.2's high-noise transformer LoRA for S2V? Which one are you using? Could you share reproducible code?
Hi @tolgacangoz. It's only the high-noise LoRA; I saw it on Reddit and wanted to test it with your pipeline: https://civitai.com/models/1909425/wan-22-14b-s2v-ultimate-suite-gguf-and-lightning-speed-with-extended-video-generation?modelVersionId=2161199
My test code didn't use any quantized version; it was your demo loading code plus this code to load the LoRA:
lightning_hn = hf_hub_download(repo_id="lightx2v/Wan2.2-Distill-Loras", filename="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", local_dir="/pretrained_models")
pipe.load_lora_weights(lightning_hn, adapter_name="light")
pipe.set_adapters(["light"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["light"], lora_scale=3., components=["transformer"])
I won't be able to test it again soon, but I'll let you know when I can.
> Without LoRA, I tried the pipeline with .to("cuda"), pipe.enable_model_cpu_offload(), and enable_group_offload, and always got the same error:
Are you sure that you are using the wan2.2-s2v branch as I emphasized in the first comment?
At the time of my test I merged your branch with my own fork https://github.com/zecloud/diffusers/pull/1; the main difference in my fork is that I added repeated_blocks in transformer_vace.py, which shouldn't alter anything in your pipeline. My test environment runs in a container with CUDA 13, the latest PyTorch, and SageAttention installed.
> At the time of my test I merged your branch with my own fork zecloud#1; the main difference in my fork is that I added repeated_blocks in transformer_vace.py, which shouldn't alter anything in your pipeline. My test environment runs in a container with CUDA 13, the latest PyTorch, and SageAttention installed.
It seems that you used this PR's branch. Can you try with the example code shared in the first comment?
Gentle nudge: any particular reason why this PR has not been reviewed and merged?
@yiyixuxu @sayakpaul @asomoza @dg845 @stevhliu
hi @tolga
I actually did go through the PR! It's a pretty complex integration and it still requires significant refactoring to meet our standards. Unfortunately, we don't have the bandwidth right now to take on the remaining refactoring ourselves.