flash-linear-attention
RWKV6 backward issue
Hi, I grabbed the 3B version of the model from the Hugging Face Hub, and when I try to call loss.backward() (after model.train()) using the transformers library, I get the following error coming from your library.
File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    _u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
                   (i_k * BK + tl.arange(0, BK)[None, :]) * \
                   DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
             ^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')
Thanks for considering this issue
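For context, the ValueError says a [16, 64] tensor cannot be broadcast against a [64, 16] one: the mask built as mask_bk[:, None] & mask_bv[None, :] has shape [BK, BV], while the initial-state block h and its pointer grid have shape [BV, BK], which only lines up when BK == BV. A minimal PyTorch sketch of the same broadcast failure, purely for illustration (this is not the Triton kernel itself):

import torch

# Block sizes taken from the traceback; the mismatch only appears when BK != BV.
BK, BV = 16, 64
mask_kv = torch.ones(BK, BV, dtype=torch.bool)  # shape [BK, BV] = [16, 64], like mask_bk[:, None] & mask_bv[None, :]
h = torch.zeros(BV, BK)                         # shape [BV, BK] = [64, 16], like tl.zeros([BV, BK])

try:
    # Applying a [16, 64] mask to a [64, 16] block violates the same broadcast rule.
    torch.where(mask_kv, h, torch.zeros_like(h))
except RuntimeError as e:
    print(e)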
@ostix360 Could you provide a minimal reproducible code snippet, e.g. the inputs to chunk_rwkv6?
The code on the Hugging Face Hub doesn't use chunk_rwkv6 but fused_recurrent_rwkv6. Here is the function that is called during the forward pass:
import torch
from fla.ops.rwkv6 import fused_recurrent_rwkv6


def rwkv6_linear_attention(
    training,
    receptance,
    key,
    value,
    time_decay,
    time_first,
    state,
):
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
    # Launching the CUDA kernel for just one token will actually be slower
    # (there is no for loop in the CPU version in this case).
    one_token = key.size(1) == 1
    if not training or no_cuda or one_token:
        return rwkv6_linear_attention_cpu(  # Not called
            receptance, key, value, time_decay, time_first, state
        )
    else:
        batch, seq_length, _ = receptance.shape
        num_heads, head_size = time_first.shape
        key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)  # B, T, H, K -> B, H, T, K
        value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)  # B, T, H, V -> B, H, T, V
        receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)  # B, T, H, K -> B, H, T, K
        time_decay = -torch.exp(time_decay.float()).view(batch, seq_length, num_heads, head_size).permute(0, 2, 1, 3)  # B, T, H, K -> B, H, T, K
        time_first = time_first.float().reshape(num_heads, head_size)  # H, K
        out, state = fused_recurrent_rwkv6(receptance, key, value, time_decay, time_first, scale=1.0, initial_state=state, output_final_state=True)
        return out.transpose(1, 2), state
Is this what you want? If you want the inputs of this function, I can give you what I tried with lower dimensions while debugging the code.
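For reference, here is a sketch of what such lower-dimensional inputs could look like, assuming the rwkv6_linear_attention function above is in scope. The shapes follow the comments in that function, but the sizes themselves and the state layout (batch, num_heads, head_size, head_size) are assumptions chosen only for illustration:

import torch

batch, seq_length, num_heads, head_size = 2, 8, 4, 16
hidden = num_heads * head_size

receptance = torch.randn(batch, seq_length, hidden, device="cuda", requires_grad=True)
key = torch.randn(batch, seq_length, hidden, device="cuda", requires_grad=True)
value = torch.randn(batch, seq_length, hidden, device="cuda", requires_grad=True)
time_decay = torch.randn(batch, seq_length, hidden, device="cuda", requires_grad=True)
time_first = torch.randn(num_heads, head_size, device="cuda", requires_grad=True)
# Assumed state layout; check against the model's actual cache shape.
state = torch.zeros(batch, num_heads, head_size, head_size, device="cuda")

out, state = rwkv6_linear_attention(True, receptance, key, value, time_decay, time_first, state)
out.sum().backward()  # the reported CompilationError surfaces during this backward pass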
@ostix360 Hello, sorry for the late reply. I just ran the examples and they passed seamlessly. Could you share your Triton version & hardware info?
Hi, Triton 2.2.0, CUDA 12.4, torch 2.2.2+cu121, GPU: RTX 4070 Ti (12 GB VRAM), CPU: AMD Ryzen 5 3600, running on WSL 2.
@ostix360 I'm not sure if the 4070 Ti would be OK. Do other kernels work for you, e.g. chunk_gla?
If not, I think Triton 2.2 is not compatible with the 4070 Ti for the current fla.
Does chunk_rwkv6 work for you?
Could you please provide some code that I can test?
https://github.com/sustcsonglin/flash-linear-attention/blob/main/tests/ops/test_gla.py
From the folder at the same level as fla, run
$ pytest -s test_gla
for GLA; RWKV6 is tested similarly.
You can also simply run the checks:
$ python -m fla.ops.rwkv6.chunk_naive
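If running the full test suite is inconvenient, here is a standalone sanity-check sketch that exercises both RWKV6 kernels forward and backward on random inputs. It assumes fla is importable and that both ops accept the same arguments as the fused_recurrent_rwkv6 call shown earlier; the shapes and the log-decay construction are arbitrary:

import torch
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6

B, H, T, K, V = 2, 4, 64, 64, 64
q = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
k = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
v = torch.randn(B, H, T, V, device="cuda", requires_grad=True)
# the log-decay must be negative; logsigmoid of random values is an arbitrary stand-in
w = torch.nn.functional.logsigmoid(torch.randn(B, H, T, K, device="cuda")).requires_grad_()
u = torch.randn(H, K, device="cuda", requires_grad=True)

for op in (fused_recurrent_rwkv6, chunk_rwkv6):
    o, _ = op(q, k, v, w, u, scale=1.0, initial_state=None, output_final_state=True)
    o.sum().backward()
    print(op.__name__, "backward ok, NaN in dq:", torch.isnan(q.grad).any().item())
    q.grad = k.grad = v.grad = w.grad = u.grad = None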
I got 6 assertion errors for test_gla
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-32-300-4-4] - AssertionError: dg diff: 0.17578125
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-32-512-4-4] - AssertionError: dg diff: 0.2265625
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-64-300-4-4] - AssertionError: dg diff: 0.1630859375
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-64-512-4-4] - AssertionError: dg diff: 0.22265625
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-128-300-4-4] - AssertionError: dg diff: 0.189453125
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-128-512-4-4] - AssertionError: dg diff: 0.234375
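For context, the "dg diff" values in these assertions are presumably the maximum absolute difference between the gradient of the decay/gate input as computed by the fused kernel and by a reference implementation. A generic way to compute such a number for any pair of implementations is sketched below; max_grad_diff is a hypothetical helper, not the test's own code:

import torch

def max_grad_diff(ref_fn, tri_fn, inputs, wrt=0):
    # Run both implementations on cloned inputs and compare the gradient of inputs[wrt].
    grads = []
    for fn in (ref_fn, tri_fn):
        xs = [x.detach().clone().requires_grad_() for x in inputs]
        out = fn(*xs)
        out = out[0] if isinstance(out, tuple) else out
        out.sum().backward()
        grads.append(xs[wrt].grad)
    return (grads[0] - grads[1]).abs().max().item()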
For the chunk_naive check I got:
tensor(146., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(146., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(197., device='cuda:0', dtype=torch.bfloat16)
tensor(221., device='cuda:0', dtype=torch.bfloat16)
tensor(54.7500, device='cuda:0', dtype=torch.bfloat16)
tensor(161., device='cuda:0', dtype=torch.bfloat16)
tensor(1736., device='cuda:0', dtype=torch.bfloat16)
Thanks for your help.
@ostix360 Hello, could you check it again? I just pushed some commits to fix some issues with the initial states.
Yes, it works fine. Thanks for your help!
My bad, there are NaNs in the model; the grad norm, however, is not NaN.
It seems that these NaNs come from the backpropagation.
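A quick way to confirm that the NaNs appear during the backward pass is to inspect the gradients right after loss.backward(). This is generic PyTorch diagnostics, nothing specific to fla:

import torch

def report_nonfinite_grads(model: torch.nn.Module) -> None:
    # Print every parameter whose gradient contains NaN or Inf after loss.backward().
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"non-finite grad in {name}")

# usage: loss.backward(); report_nonfinite_grads(model)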
Hello, can you provide a code snippet that produces the NaNs?
Hi, sorry, it was a weight initialization problem. Everything works fine!