
[Bug] ring parallel can't work in wan2.2, RingAttnWeight.apply() got an unexpected keyword argument 'slice_qkv_len'

Open: laiaqwq opened this issue 1 week ago • 0 comments

Description

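Ring sequence parallelism does not work with Wan2.2. transformer_infer.py calls RingAttnWeight.apply() with the keyword argument slice_qkv_len, but the method's parameter is still named img_qkv_len, so the parameter should be renamed to slice_qkv_len.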
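The mismatch can be reproduced without GPUs, a process group, or model weights. The sketch uses hypothetical stand-ins that copy only the two signatures quoted under Additional Information; the class body and the helper names are illustrative and are not LightX2V code.

```python
# Hypothetical stand-in: mirrors ONLY the signature reported for
# RingAttnWeight.apply() at commit 138876d; the real method does ring attention.
class RingAttnWeight:
    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None,
              seq_p_group=None, model_cls=None, use_fp8_comm=False):
        return q  # placeholder body


def call_like_transformer_infer(attn, q, k, v, img_qkv_len, cu_seqlens_qkv):
    # Hypothetical helper: passes the length under the keyword the caller in
    # transformer_infer.py (line 178) uses, which the signature above lacks.
    return attn.apply(q=q, k=k, v=v, slice_qkv_len=img_qkv_len,
                      cu_seqlens_qkv=cu_seqlens_qkv)


def reproduce():
    """Return the TypeError message raised by the mismatched keyword, or None."""
    try:
        call_like_transformer_infer(RingAttnWeight(), 1, 2, 3, 4, None)
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce())
```

On the Python 3.12 used in this report, reproduce() returns the same message as the final line of the log in Log Information, because Python rejects the unknown keyword before it checks for the missing img_qkv_len argument.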

Steps to Reproduce

pipe = LightX2VPipeline(
    model_path="/root/wan22",
    model_cls="wan2.2_moe_distill",
    task="i2v",
    low_noise_original_ckpt="/root/wan22/low_noise_model/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
    high_noise_original_ckpt="/root/wan22/high_noise_model/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
    # Distilled weights: for wan2.1, only dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors" needs to be specified
)
...
pipe.enable_parallel(
    seq_p_size=8,            # Sequence parallel size
    seq_p_attn_type="ring",  # Sequence parallel attention type
)
pipe.create_generator(
    attn_mode="sage_attn2",
    infer_steps=4,
    height=720,  # Can be set to 720 for higher resolution
    width=1280,  # Can be set to 1280 for higher resolution
    num_frames=81,
    guidance_scale=1,
    sample_shift=5.0,
)

Expected Result

The pipeline runs and generates the video without errors.

Actual Result

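pipe.generate() crashes with a TypeError raised from the ring attention call in the self-attention block (full log in Log Information).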

Environment Information

  • Ubuntu 22.04
  • torch 2.9.0
  • Commit ID: 138876d06888a8e573f011c44ec2a9eedd622ec2

Log Information

[rank1]: Traceback (most recent call last):
[rank1]:   File "/root/test.py", line 78, in <module>
[rank1]:     pipe.generate(
[rank1]:   File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/pipeline.py", line 348, in generate
[rank1]:     self.runner.run_pipeline(input_info)
[rank1]:   File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 463, in run_pipeline
[rank1]:     gen_video_final = self.run_main()
[rank1]:                       ^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 368, in run_main
[rank1]:     latents = self.run_segment(segment_idx)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/utils/memory_profiler.py", line 18, in wrapper
[rank1]:     result = func(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/runners/default_runner.py", line 183, in run_segment
[rank1]:     self.model.infer(self.inputs)
[rank1]:   File "/root/LightX2V/lightx2v/models/runners/wan/wan_distill_runner.py", line 91, in infer
[rank1]:     self.model[self.cur_model_index].infer(inputs)
[rank1]:   File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/model.py", line 445, in infer
[rank1]:     self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
[rank1]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/utils/custom_compiler.py", line 47, in wrapper
[rank1]:     return state["original_func"](self, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/model.py", line 464, in _infer_cond_uncond
[rank1]:     x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 68, in infer
[rank1]:     x = self.infer_main_blocks(weights.blocks, pre_infer_out)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 72, in infer_main_blocks
[rank1]:     x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/infer/offload/transformer_infer.py", line 47, in infer_with_blocks_offload
[rank1]:     x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 113, in infer_block
[rank1]:     y_out = self.infer_self_attn(
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 178, in infer_self_attn
[rank1]:     attn_out = phase.self_attn_1_parallel.apply(
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: TypeError: RingAttnWeight.apply() got an unexpected keyword argument 'slice_qkv_len'

Additional Information

The caller in "LightX2V/lightx2v/models/networks/wan/infer/transformer_infer.py", line 178, passes the length as slice_qkv_len:

            attn_out = phase.self_attn_1_parallel.apply(
                q=q,
                k=k,
                v=v,
                slice_qkv_len=img_qkv_len,
                cu_seqlens_qkv=cu_seqlens_qkv,
                attention_module=phase.self_attn_1,
                seq_p_group=self.seq_p_group,
                use_fp8_comm=self.seq_p_fp8_comm,
                model_cls=self.config["model_cls"],
            )

but RingAttnWeight.apply() still declares that parameter as img_qkv_len:

class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
        ...  # method body omitted in this report

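The img_qkv_len parameter of RingAttnWeight.apply() should therefore be renamed to slice_qkv_len so that it matches the keyword used by the caller.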
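A minimal sketch of the proposed rename, together with a signature check that could serve as a cheap regression test, since it needs no GPUs or process group. AttnWeightTemplate is stubbed here, and accepts_call_site_kwargs is a hypothetical helper. The real apply() body is not shown in this report, so any references to img_qkv_len inside it would need the same rename.

```python
import inspect


class AttnWeightTemplate:
    """Hypothetical stub so the sketch runs standalone; not the LightX2V base class."""


class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    # Proposed fix: img_qkv_len -> slice_qkv_len, matching the keyword passed
    # from transformer_infer.py. Uses inside the real body must be renamed too.
    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None,
              seq_p_group=None, model_cls=None, use_fp8_comm=False):
        return slice_qkv_len  # placeholder body for the sketch


def accepts_call_site_kwargs(method):
    """Return True if the keywords used at transformer_infer.py line 178 bind to `method`."""
    call_site_kwargs = dict(
        q=None, k=None, v=None, slice_qkv_len=None, cu_seqlens_qkv=None,
        attention_module=None, seq_p_group=None, use_fp8_comm=False, model_cls=None,
    )
    try:
        # Signature.bind() applies Python's argument-binding rules without calling.
        inspect.signature(method).bind(**call_site_kwargs)
    except TypeError:
        return False
    return True


if __name__ == "__main__":
    print(accepts_call_site_kwargs(RingAttnWeight().apply))  # True after the rename
```

Run against the unpatched signature, the same check returns False, because inspect.Signature.bind() raises TypeError on the unexpected slice_qkv_len keyword.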

laiaqwq, Dec 19 '25 04:12