
wan_video_dit issue with flash_attn 3

Open matabear-wyx opened this issue 8 months ago • 3 comments

```
Traceback (most recent call last):
  File "/data/yuxiong/DiffSynth-Studio/examples/wanvideo/wan_fun_InP.py", line 65, in <module>
    video = pipe(
            ^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/pipelines/wan_video.py", line 349, in __call__
    noise_pred_posi = model_fn_wan_video(
                      ^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/pipelines/wan_video.py", line 484, in model_fn_wan_video
    x = block(x, context, t_mod, freqs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 216, in forward
    x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 143, in forward
    x = self.attn(q, k, v)
        ^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 117, in forward
    x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 39, in flash_attention
    x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/einops.py", line 600, in rearrange
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/einops.py", line 527, in reduce
    backend = get_backend(tensor)
              ^^^^^^^^^^^^^^^^^^^
  File "/data/yuxiong/miniconda3/envs/torch26/lib/python3.11/site-packages/einops/_backends.py", line 59, in get_backend
    raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor)))
RuntimeError: Tensor type unknown to einops <class 'tuple'>
```
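For context (not stated in the thread, so treat it as an assumption): the final `RuntimeError` suggests the flash-attention call returned a *tuple* rather than a tensor. FlashAttention-3's `flash_attn_func` returns `(output, softmax_lse)` in some versions, whereas FlashAttention-2 returns the output tensor alone, so a caller written for version 2 can end up passing a tuple into `einops.rearrange`. A minimal, version-agnostic workaround (the helper name `unwrap_flash_attn_output` is hypothetical, not part of DiffSynth-Studio) is to unpack before rearranging:

```python
def unwrap_flash_attn_output(result):
    """Normalize the return value across flash-attn versions.

    FlashAttention-3 may return (output, softmax_lse); FlashAttention-2
    returns just the output tensor. Downstream code (e.g. einops.rearrange)
    needs the bare tensor, so keep only the first element of a tuple.
    """
    if isinstance(result, tuple):
        return result[0]
    return result

# Sketch of use inside an attention wrapper (flash_attn_func is assumed
# to be imported from whichever flash-attn interface is installed):
#   x = unwrap_flash_attn_output(flash_attn_func(q, k, v))
#   x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
```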

matabear-wyx avatar Apr 11 '25 01:04 matabear-wyx

I have run into the same problem. How can it be solved?

lgs00 avatar Apr 14 '25 03:04 lgs00

@lgs00 We have received this issue and will fix it.

Artiprocher avatar Apr 16 '25 02:04 Artiprocher

@lgs00 We have fixed this bug. However, we don't have a Hopper GPU to test FlashAttention-3 on. Could you help us test it?

Artiprocher avatar Apr 17 '25 08:04 Artiprocher