flash-linear-attention
[Bug] RoPE attention encounters illegal memory access for long sequence decoding
Checklist
- [x] I have checked FAQs and existing issues for similar problems
- [x] My GPU is H100 and I have installed `triton-nightly` built by the fla team, and double checked FAQs
- [x] Please report this bug in English to ensure wider understanding and support
Describe the Bug
File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/fla/models/transformer/modeling_transformer.py", line 76, in forward
hidden_states, attentions, past_key_values = self.attn(
^^^^^^^^^^
File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/fla/layers/attn.py", line 114, in forward
max_seqlen = q.shape[1] + max(seqlen_offset)
^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
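Since CUDA errors are reported asynchronously, the traceback above probably points at the first synchronization point (the `max(seqlen_offset)` call) rather than at the faulting kernel itself. A minimal sketch of rerunning with synchronous launches to localize it (this is the standard `CUDA_LAUNCH_BLOCKING` mechanism, nothing fla-specific):
# Sketch: force synchronous kernel launches so the traceback points at the
# actual faulting kernel instead of the next synchronization point.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before CUDA is initialized
import torch  # import torch only after setting the variable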
Steps to Reproduce the Bug
N/A
Expected Behavior
N/A
Environment Information
N/A
I think it's because of `tl.int32`? How can I reproduce this error?
I might have also encountered this problem. From my tests it happens at the 2nd transformer layer during decoding on H100 GPUs, with length >= 256 and negative `seqlen_offset` values (e.g. when `attention_mask` is used and the cache is enabled). It works on Ampere GPUs, with smaller lengths, with only non-negative offsets, or when the RoPE function is not called in an attention layer.
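For context, here is a rough sketch of where I think the negative offsets come from (assumed arithmetic for a left-padded batch with the KV cache enabled, not necessarily fla's exact code): the per-sequence offset is roughly cached length + actual length - padded length, so on the first cached forward the padded row is already negative.
# Sketch (assumed arithmetic, not fla's exact implementation): per-sequence
# offsets for a left-padded batch with the KV cache enabled.
import torch

seq_lengths = torch.tensor([256, 257])   # actual lengths, as in the script below
padded_len = int(seq_lengths.max())      # both rows padded to 257
cached_len = 0                           # first forward pass, empty cache
# offset_i ~= cached_len + actual_len_i - padded_len
seqlen_offset = cached_len + seq_lengths - padded_len
print(seqlen_offset)                     # tensor([-1,  0]); the padded row goes negative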
To reproduce the error, a minimal script would be:
import fla  # registers the fla model classes with transformers' Auto* mappings
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

name = 'fla-hub/transformer-340M-10B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model = model.cuda()
model = model.bfloat16()

# two sequences of different lengths, so the shorter one is left-padded
seq_lengths = [256, 257]
B = len(seq_lengths)
L = max(seq_lengths)
input_ids = torch.ones((B, L), dtype=torch.int64).cuda()
attention_mask = torch.zeros((B, L), dtype=torch.int64).cuda()
for i in range(B):
    attention_mask[i, -seq_lengths[i]:] = 1  # mark the last seq_lengths[i] tokens as real
r = model.generate(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
I'm using the latest fla (2816af4), PyTorch 2.7.0, and Triton 3.3.0.