Traceback (most recent call last):
File "/data/yuxiong/DiffSynth-Studio/examples/wanvideo/wan_fun_InP.py", line 65, in
video = pipe(
^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/pipelines/wan_video.py", line 349, in call
noise_pred_posi = model_fn_wan_video(
^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/pipelines/wan_video.py", line 484, in model_fn_wan_video
x = block(x, context, t_mod, freqs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 216, in forward
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 143, in forward
x = self.attn(q, k, v)
^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 117, in forward
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 39, in flash_attention
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/einops.py", line 600, in rearrange
return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/einops.py", line 527, in reduce
backend = get_backend(tensor)
^^^^^^^^^^^^^^^^^^^
File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/_backends.py", line 59, in get_backend
raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor)))
RuntimeError: Tensor type unknown to einops <class 'tuple'>
I have run into the same problem. How can I solve it?
@lgs00 We have received this issue and will fix it.
@lgs00 We have fixed this bug. However, we don't have a Hopper GPU to test FlashAttention-3. Could you help us test it?
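For anyone hitting this before pulling the fix: the `<class 'tuple'>` in the error suggests the attention backend returned more than one value, so `rearrange` received a tuple instead of a tensor. Below is a minimal sketch of the kind of guard involved, assuming the tuple comes from FlashAttention-3, whose `flash_attn_interface.flash_attn_func` can return `(out, softmax_lse)` rather than a single tensor; the helper name is illustrative and not DiffSynth-Studio's actual code.

from einops import rearrange

def merge_heads(x, num_heads):
    # FlashAttention-3 may return a (output, softmax_lse) tuple; keep only the output tensor.
    if isinstance(x, tuple):
        x = x[0]
    # Merge the head dimension back: (b, s, n, d) -> (b, s, n*d).
    return rearrange(x, "b s n d -> b s (n d)", n=num_heads)

Applying a check like this right after the flash-attention call in `flash_attention` (wan_video_dit.py) keeps the rest of the pipeline unchanged, since only the attention output is needed downstream.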