Description
Ring sequence parallel does not work with wan2.2. The self-attention call site passes the keyword argument slice_qkv_len, but RingAttnWeight.apply still declares the parameter as img_qkv_len, so inference crashes with a TypeError; the parameter should be renamed to slice_qkv_len.
Steps to Reproduce
pipe = LightX2VPipeline(
    model_path="/root/wan22",
    model_cls="wan2.2_moe_distill",
    task="i2v",
    low_noise_original_ckpt="/root/wan22/low_noise_model/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
    high_noise_original_ckpt="/root/wan22/high_noise_model/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
    # Distilled weights: for wan2.1, only need to specify dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors"
)
...
pipe.enable_parallel(
    seq_p_size=8,  # Sequence parallel size
    seq_p_attn_type="ring",  # Sequence parallel attention type
)
pipe.create_generator(
    attn_mode="sage_attn2",
    infer_steps=4,
    height=720,  # Can be set to 720 for higher resolution
    width=1280,  # Can be set to 1280 for higher resolution
    num_frames=81,
    guidance_scale=1,
    sample_shift=5.0,
)
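Since seq_p_size=8 requires 8 processes, the script (/root/test.py in the traceback below) was run under a multi-rank launcher. As an assumption, a generic torchrun launch would look like the following; the exact launcher and flags LightX2V expects may differ:

torchrun --nproc_per_node=8 /root/test.py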
Expected Result
The pipeline runs and generates the video successfully.
Actual Result
The run crashes with TypeError: RingAttnWeight.apply() got an unexpected keyword argument 'slice_qkv_len' (full traceback below).
Environment Information
- Ubuntu 22.04
- torch 2.9.0
- Commit ID: 138876d06888a8e573f011c44ec2a9eedd622ec2
Log Information
[rank1]: Traceback (most recent call last):
[rank1]: File "/root/test.py", line 78, in
[rank1]: pipe.generate(
[rank1]: File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/pipeline.py", line 348, in generate
[rank1]: self.runner.run_pipeline(input_info)
[rank1]: File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 463, in run_pipeline
[rank1]: gen_video_final = self.run_main()
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 368, in run_main
[rank1]: latents = self.run_segment(segment_idx)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/utils/memory_profiler.py", line 18, in wrapper
[rank1]: result = func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 183, in run_segment
[rank1]: self.model.infer(self.inputs)
[rank1]: File "/root/LightX2V/lightx2v/models/runners/wan/wan_distill_runner.py", line 91, in infer
[rank1]: self.model[self.cur_model_index].infer(inputs)
[rank1]: File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/model.py", line 445, in infer
[rank1]: self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/utils/custom_compiler.py", line 47, in wrapper
[rank1]: return state["original_func"](self, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/model.py", line 464, in _infer_cond_uncond
[rank1]: x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 68, in infer
[rank1]: x = self.infer_main_blocks(weights.blocks, pre_infer_out)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 72, in infer_main_blocks
[rank1]: x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/infer/offload/transformer_infer.py", line 47, in infer_with_blocks_offload
[rank1]: x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 113, in infer_block
[rank1]: y_out = self.infer_self_attn(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 178, in infer_self_attn
[rank1]: attn_out = phase.self_attn_1_parallel.apply(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: TypeError: RingAttnWeight.apply() got an unexpected keyword argument 'slice_qkv_len'
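This TypeError is standard Python behavior whenever a keyword argument name does not match any parameter in the callee's signature. A minimal standalone illustration of the same failure mode (demo code with a hypothetical class name, not from LightX2V):

class RingAttnWeightDemo:
    # The method declares the parameter as img_qkv_len...
    def apply(self, q, k, v, img_qkv_len):
        return q, k, v, img_qkv_len

# ...but the caller passes the value under the name slice_qkv_len.
RingAttnWeightDemo().apply(q=1, k=2, v=3, slice_qkv_len=4)
# TypeError: RingAttnWeightDemo.apply() got an unexpected keyword argument 'slice_qkv_len'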
Additional Information
input in "LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 178
attn_out = phase.self_attn_1_parallel.apply(
    q=q,
    k=k,
    v=v,
    slice_qkv_len=img_qkv_len,
    cu_seqlens_qkv=cu_seqlens_qkv,
    attention_module=phase.self_attn_1,
    seq_p_group=self.seq_p_group,
    use_fp8_comm=self.seq_p_fp8_comm,
    model_cls=self.config["model_cls"],
)
but the method signature in RingAttnWeight still declares img_qkv_len:
class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
Renaming the img_qkv_len parameter of RingAttnWeight.apply to slice_qkv_len resolves the mismatch.
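A minimal sketch of the proposed fix, assuming nothing else changes: only the parameter is renamed, and any references to img_qkv_len inside the method body (elided here) get the same rename.

class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    # Renamed img_qkv_len -> slice_qkv_len to match the keyword used at the call site.
    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
        ...  # body unchanged apart from the rename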