ring-flash-attention
Error when increasing the sequence length
Hi, when I increase the seqlen from 1024 * 8 to 1024 * 64 here:
https://github.com/zhuzilin/ring-flash-attention/blob/9e2a7e543d6461cc935d44142fc99660de7b8579/benchmark/benchmark_varlen_qkvpacked_func.py#L18
and then run the benchmark with:
torchrun benchmark/benchmark_varlen_qkvpacked_func.py
the zigzag varlen benchmark fails. The first two benchmarks still complete. The full log is:
# flash_attn_varlen_qkvpacked_func
329.0089328816957 iter/s, 0.303943115234375 sec
# ring_flash_attn_varlen_qkvpacked_func
125.49088812377029 iter/s, 0.79687060546875 sec
# zigzag_ring_flash_attn_varlen_qkvpacked_func
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/zecheng/lcm_stack/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py", line 99, in <module>
[rank0]: benchmark(f, forward_only=forward_only, log=False)
[rank0]: File "/data/zecheng/lcm_stack/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py", line 64, in benchmark
[rank0]: out = f(
[rank0]: File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 413, in zigzag_ring_flash_attn_varlen_qkvpacked_func
[rank0]: return ZigZagRingFlashAttnVarlenFunc.apply(
[rank0]: File "/data/anaconda3/envs/new_zecheng/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]: File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 331, in forward
[rank0]: out, softmax_lse = zigzag_ring_flash_attn_varlen_forward(
[rank0]: File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 94, in zigzag_ring_flash_attn_varlen_forward
[rank0]: q1 = q[half_index1]
[rank0]: IndexError: The shape of the mask [8192] at index 0 does not match the shape of the indexed tensor [65536, 5, 128] at index 0
E1003 16:05:54.677000 139778620753728 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 140219) of binary: /data/anaconda3/envs/new_zecheng/bin/python3.10
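The mask has 8192 entries (the old seqlen) while q has 65536 rows (the new seqlen), so something still seems to be sized for the original length. How can I fix this?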
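One possible cause, which I have not confirmed against the repository: the cumulative sequence boundaries (cu_seqlens) used to build the mask may still end at 8192, so they no longer cover the 65536 packed tokens. The sketch below is a plain-Python sanity check for that condition. The boundary values and the helper names `covers_all_tokens` and `scale_cu_seqlens` are hypothetical and are not part of ring-flash-attention.

```python
# Hypothetical sanity check; helper names and boundary values are assumptions,
# not taken from the ring-flash-attention repository.

def covers_all_tokens(cu_seqlens, total_tokens):
    """True if the boundaries start at 0, strictly increase, and end at total_tokens."""
    increasing = all(a < b for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    return cu_seqlens[0] == 0 and cu_seqlens[-1] == total_tokens and increasing


def scale_cu_seqlens(cu_seqlens, factor):
    """Stretch every boundary by the same integer factor."""
    return [c * factor for c in cu_seqlens]


old_seqlen = 1024 * 8
new_seqlen = 1024 * 64

# Example boundaries that sum to the old length (assumed for illustration).
cu_seqlens = [0, 2048, 4096, 8192]

ok_before = covers_all_tokens(cu_seqlens, old_seqlen)   # boundaries match old length
ok_after = covers_all_tokens(cu_seqlens, new_seqlen)    # stale boundaries, new length

# Rescale the boundaries by the same factor as the sequence length.
scaled = scale_cu_seqlens(cu_seqlens, new_seqlen // old_seqlen)
ok_scaled = covers_all_tokens(scaled, new_seqlen)
```

If the boundaries in the benchmark are fixed values like these, scaling them together with seqlen (or deriving them from seqlen) should keep the mask and the packed tensor the same length.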