ring-flash-attention
Got error in ZigZagRingFlashAttnVarlenFunc
- It seems the batch dimension disappears after the `_upad_input` function (this function is usually copied from `transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input`). As a result, the `block_lse` obtained at L118 of `zigzag_ring_flash_attn_varlen.py` has only 2 dimensions (`num_head` and `seq_len`). This causes an error in the `flatten_varlen_lse` function (L120 of `zigzag_ring_flash_attn_varlen.py`), which requires `block_lse` to have three dimensions.
- An illegal memory access error is reported in the `else` branch at L135 of `zigzag_ring_flash_attn_varlen.py`. I cannot even print the `half_cu_seqlens` or `cu_seqlens` tensor before the `flatten_varlen_lse` function is called:
```
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 140, in zigzag_ring_flash_attn_varlen_forward
    print(cu_seqlens)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 431, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 595, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 347, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 133, in __init__
    value_str = f"{value}"
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 933, in __format__
    return self.item().__format__(format_spec)
RuntimeError: CUDA error: an illegal memory access was encountered
```
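Both points come down to bookkeeping on the varlen inputs, which can be sketched without a GPU. The helpers below (`ndim`, `flatten_padded_lse`, `validate_cu_seqlens`) are hypothetical and are not part of ring-flash-attention; they use nested Python lists as stand-ins for tensors. `flatten_padded_lse` assumes a padded LSE layout of `[batch][nheads][max_seqlen]`, packs it into `[nheads][total_tokens]` using `cu_seqlens`, and passes an already 2-D LSE through unchanged, which is the situation described in the first point. `validate_cu_seqlens` checks `cu_seqlens` against the token count of the unpadded input before any kernel runs; its even-length check is an assumption based on the zigzag layout splitting each local sequence into a front half and a back half. Note also that CUDA reports kernel faults asynchronously, so the `print` in the traceback most likely only surfaces an error raised by an earlier kernel (the `.item()` call forces a synchronization); rerunning with `CUDA_LAUNCH_BLOCKING=1` should point closer to the faulting call.

```python
from typing import List, Sequence

# Hypothetical debugging helpers (not ring-flash-attention APIs).
# Nested lists stand in for tensors so the index arithmetic is easy to follow.


def ndim(x) -> int:
    """Return the nesting depth of a nested list, i.e. its number of dimensions."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0] if x else None
    return depth


def flatten_padded_lse(lse, cu_seqlens: Sequence[int]) -> List[List[float]]:
    """Pack a padded LSE [batch][nheads][max_seqlen] into [nheads][total_tokens].

    Assumption: an LSE that is already 2-D ([nheads][total_tokens]) needs no
    flattening, so it is returned as a copy instead of raising an error.
    """
    dims = ndim(lse)
    if dims == 2:
        return [list(row) for row in lse]
    if dims != 3:
        raise ValueError(f"expected a 2-D or 3-D LSE, got {dims}-D")
    batch = len(cu_seqlens) - 1
    if len(lse) != batch:
        raise ValueError(f"LSE batch size {len(lse)} != {batch} sequences in cu_seqlens")
    nheads = len(lse[0])
    packed: List[List[float]] = [[] for _ in range(nheads)]
    for b in range(batch):
        seqlen = cu_seqlens[b + 1] - cu_seqlens[b]
        for h in range(nheads):
            # Keep only the real tokens of sequence b and drop the padding.
            packed[h].extend(lse[b][h][:seqlen])
    return packed


def validate_cu_seqlens(
    cu_seqlens: Sequence[int], total_tokens: int, require_even: bool = True
) -> None:
    """Raise ValueError if cu_seqlens cannot describe an input of total_tokens tokens."""
    if len(cu_seqlens) < 2 or cu_seqlens[0] != 0:
        raise ValueError("cu_seqlens must start at 0 and have at least two entries")
    if cu_seqlens[-1] != total_tokens:
        raise ValueError(
            f"cu_seqlens[-1] = {cu_seqlens[-1]}, but the input has {total_tokens} tokens"
        )
    for i, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        seqlen = end - start
        if seqlen <= 0:
            raise ValueError(f"sequence {i} has non-positive length {seqlen}")
        # Assumption: the zigzag layout halves every local sequence.
        if require_even and seqlen % 2 != 0:
            raise ValueError(f"sequence {i} has odd length {seqlen}")
```

With real tensors, `validate_cu_seqlens(cu_seqlens.tolist(), q.shape[0])` could be called right before the attention function; calling `.tolist()` on a GPU tensor also forces a synchronization, so this is meant as a debugging aid rather than something to keep in the hot path.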