
[Bug] RoPE attention encounters illegal memory access for long-sequence decoding

Open · sustcsonglin opened this issue 6 months ago · 1 comment

Checklist

  • [x] I have checked FAQs and existing issues for similar problems
  • [x] My GPU is an H100, I have installed the triton-nightly built by the fla team, and I have double-checked the FAQs
  • [x] Please report this bug in English to ensure wider understanding and support

Describe the Bug

 File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/fla/models/transformer/modeling_transformer.py", line 76, in forward
    hidden_states, attentions, past_key_values = self.attn(
                                                 ^^^^^^^^^^
  File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/cl/user/yangsl66/miniconda3/envs/fla/lib/python3.12/site-packages/fla/layers/attn.py", line 114, in forward
    max_seqlen = q.shape[1] + max(seqlen_offset)
                              ^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Steps to Reproduce the Bug

N/A

Expected Behavior

N/A

Environment Information

N/A

sustcsonglin · May 22 '25 01:05

I think it's because of tl.int32? How can I reproduce this error?

zhiyuan1i · May 23 '25 02:05

I might have encountered this problem as well. In my tests it happens at the 2nd transformer layer during decoding on H100 GPUs, with length >= 256 and with negative seqlen_offset values (e.g. when attention_mask is used and the cache is enabled). It works fine on Ampere GPUs, with smaller lengths, with only non-negative offsets, or when the RoPE function is not called in an attention layer.
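
Here is a rough sketch of how the negative offsets arise, as I understand it (my own illustration, not the actual fla code): with left padding, every sequence shorter than the padded length ends up with a negative offset at prefill.

import torch

# Hypothetical illustration (not fla's code): assume the per-sequence RoPE
# offset is computed roughly as
#   offset[i] = cache_len + true_len[i] - padded_len
attention_mask = torch.tensor([[0] + [1] * 256,   # true length 256, padded to 257
                               [1] * 257])        # true length 257
padded_len = attention_mask.shape[-1]             # 257
true_len = attention_mask.sum(-1)                 # tensor([256, 257])
cache_len = 0                                     # empty cache at prefill
seqlen_offset = cache_len + true_len - padded_len
print(seqlen_offset)                              # tensor([-1, 0]) -> negative entry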

To reproduce the error, a minimal script would be:

import fla
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

name = 'fla-hub/transformer-340M-10B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model = model.cuda()
model = model.bfloat16()

# Batch with mismatched lengths, left-padded so the shorter sequence
# gets a negative seqlen_offset during cached decoding.
seq_lengths = [256, 257]
B = len(seq_lengths)
L = max(seq_lengths)
input_ids = torch.ones((B, L), dtype=torch.int64).cuda()
attention_mask = torch.zeros((B, L), dtype=torch.int64).cuda()
for i in range(B):
    attention_mask[i, -seq_lengths[i]:] = 1

r = model.generate(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
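
If it helps with localizing the failing kernel, the synchronous launches suggested in the error message can be enabled at the very top of the script, before CUDA is initialized; this is just the standard CUDA_LAUNCH_BLOCKING environment variable, nothing fla-specific.

import os
# Make kernel launches synchronous so the error surfaces at the offending launch.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"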

I'm using the latest fla (2816af4), pytorch 2.7.0, and triton 3.3.0.

mutiann · Jun 03 '25 14:06