
Got error in ZigZagRingFlashAttnVarlenFunc

ThisisBillhe opened this issue 5 months ago · 4 comments

  1. The batch dimension seems to disappear after the `_upad_input` function (this function is usually copied from `transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input`). As a result, the `block_lse` obtained at L118 of `zigzag_ring_flash_attn_varlen.py` has only 2 dimensions (`num_heads` and `seq_len`). This causes an error in the `flatten_varlen_lse` function (L120 of `zigzag_ring_flash_attn_varlen.py`), where `block_lse` is required to have three dimensions.
  2. An illegal memory access error is reported in the `else` branch at L135 of `zigzag_ring_flash_attn_varlen.py`. Since CUDA reports errors asynchronously, the failure surfaces later on the host side: I cannot even print the `half_cu_seqlens` or `cu_seqlens` tensor before the `flatten_varlen_lse` call:
```
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 140, in zigzag_ring_flash_attn_varlen_forward
    print(cu_seqlens)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 431, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 595, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 347, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor_str.py", line 133, in __init__
    value_str = f"{value}"
  File "/mnt/workspace/anaconda3/envs/longva/lib/python3.10/site-packages/torch/_tensor.py", line 933, in __format__
    return self.item().__format__(format_spec)
RuntimeError: CUDA error: an illegal memory access was encountered
```
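To illustrate point 1, here is a minimal NumPy sketch of the shape mismatch (illustrative shapes only, not the library's code; the variable names are assumptions): `_upad_input`-style unpadding packs all sequences in the batch into one token dimension, so the varlen kernel returns an LSE of shape `(num_heads, total_tokens)`, while a function expecting padded input wants the three-dimensional `(batch, num_heads, max_seqlen)`.

```python
import numpy as np

# Hypothetical shapes mirroring the issue (not the library's actual code).
batch, max_seqlen, num_heads = 2, 4, 3
seqlens = np.array([3, 2])                               # real length of each sample
cu_seqlens = np.concatenate([[0], np.cumsum(seqlens)])   # packed offsets: [0, 3, 5]

# Padded layout that a 3-D-expecting function assumes: (batch, num_heads, max_seqlen)
padded_lse = np.zeros((batch, num_heads, max_seqlen))

# After unpadding, the batch axis is folded into a single packed token axis,
# so the LSE comes back with only 2 dimensions: (num_heads, total_tokens)
total_tokens = int(cu_seqlens[-1])                       # 5
unpadded_lse = np.zeros((num_heads, total_tokens))

print(padded_lse.ndim)    # 3 dims expected downstream
print(unpadded_lse.ndim)  # 2 dims actually produced -- the mismatch in the issue
```

This is why `block_lse` arrives with 2 dimensions at the `flatten_varlen_lse` call even though that function requires 3.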

ThisisBillhe · Sep 03 '24 03:09