
RWKV6 backward issue

Open ostix360 opened this issue 1 year ago • 9 comments

Hi, I grabbed the 3B version of the model from the Hugging Face Hub, and when I try to call loss.backward() (after model.train()) using the transformers library, I get the following error coming from your library.

  File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    _u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
                     ^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')
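
Roughly, this is the code that triggers it (a minimal sketch; the exact checkpoint id is an assumption on my side):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint id assumed; I'm using the 3B RWKV6 model from the Hub.
model_id = "RWKV/rwkv-6-world-3b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).cuda()
model.train()  # training mode, so the Triton kernel path is taken

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()  # the CompilationError above is raised here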

Thanks for considering this issue.

ostix360 avatar May 15 '24 17:05 ostix360

@ostix360 Could you provide a minimal reproducible code snippet, i.e., the inputs to chunk_rwkv6?

yzhangcs avatar May 15 '24 18:05 yzhangcs

The code on the Hugging Face Hub doesn't use chunk_rwkv6 but fused_recurrent_rwkv6. Here is the function that is called during the forward pass:

def rwkv6_linear_attention(
    training,
    receptance,
    key,
    value,
    time_decay,
    time_first,
    state,
):
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
    # in this case).
    one_token = key.size(1) == 1
    if not training or no_cuda or one_token:
        return rwkv6_linear_attention_cpu(         # Not called 
            receptance, key, value, time_decay, time_first, state
        )
    else:
        batch, seq_length, _ = receptance.shape
        num_heads, head_size = time_first.shape
        key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K -> B, H, T, K
        value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, V -> B, H, T, V
        receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, H, T, K
        time_decay = -torch.exp(time_decay.float()).view(batch, seq_length, num_heads, head_size).permute(0, 2, 1, 3) # B, T, H, K -> B, H, T, K
        time_first = time_first.float().reshape(num_heads, head_size) # H, K
        out, state = fused_recurrent_rwkv6(receptance, key, value, time_decay, time_first, scale=1.0, initial_state=state, output_final_state=True)
        return out.transpose(1, 2), state

Is this what you want? If you want the inputs to this function, I can give you what I tried with lower dimensions while debugging the code.
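
For example, something like this (a sketch: the small shapes, the import path, and the initial-state layout are my assumptions; the call itself mirrors the function above):

import torch
from fla.ops.rwkv6 import fused_recurrent_rwkv6  # import path assumed

B, H, T, K, V = 1, 4, 8, 16, 16  # small debugging sizes
receptance = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
key = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
value = torch.randn(B, H, T, V, device="cuda", requires_grad=True)
# time_decay must be negative, as with -torch.exp(...) in the function above
time_decay = -torch.exp(torch.randn(B, H, T, K, device="cuda")).requires_grad_()
time_first = torch.randn(H, K, device="cuda", requires_grad=True)
state = torch.randn(B, H, K, V, device="cuda")  # initial_state layout assumed

out, final_state = fused_recurrent_rwkv6(
    receptance, key, value, time_decay, time_first,
    scale=1.0, initial_state=state, output_final_state=True,
)
out.sum().backward()  # the CompilationError appears in this backward pass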

ostix360 avatar May 15 '24 19:05 ostix360

@ostix360 Hello, sorry for the late reply. I just ran the examples and they passed seamlessly. Could you share your Triton version and hardware info?

yzhangcs avatar May 17 '24 07:05 yzhangcs

Hi, Triton 2.2.0, CUDA 12.4, torch 2.2.2+cu121, GPU RTX 4070 Ti (12 GB VRAM), CPU AMD Ryzen 5 3600, running on WSL 2.

ostix360 avatar May 17 '24 09:05 ostix360

@ostix360 I'm not sure if the 4070 Ti would be OK. Do other kernels work for you, e.g., chunk_gla? If not, I think Triton 2.2 is not compatible with the 4070 Ti for the current fla.

yzhangcs avatar May 17 '24 09:05 yzhangcs

Does chunk_rwkv6 work for you?

yzhangcs avatar May 17 '24 09:05 yzhangcs

Could you please provide some code that I can test?

ostix360 avatar May 17 '24 09:05 ostix360

https://github.com/sustcsonglin/flash-linear-attention/blob/main/tests/ops/test_gla.py

From the folder at the same level as fla:

$ pytest -s tests/ops/test_gla.py

That covers GLA; RWKV6 works similarly.

You can also simply run the built-in checks:

$ python -m fla.ops.rwkv6.chunk_naive
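
Or, as a standalone snippet (a sketch; I'm assuming chunk_rwkv6 takes the same arguments as fused_recurrent_rwkv6 in your forward function above):

import torch
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6  # import path assumed

B, H, T, K, V = 2, 4, 64, 32, 32
q = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
k = torch.randn(B, H, T, K, device="cuda", requires_grad=True)
v = torch.randn(B, H, T, V, device="cuda", requires_grad=True)
w = -torch.exp(torch.randn(B, H, T, K, device="cuda")).requires_grad_()  # decay < 0
u = torch.randn(H, K, device="cuda", requires_grad=True)

ref, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0, output_final_state=True)
out, _ = chunk_rwkv6(q, k, v, w, u, scale=1.0, output_final_state=True)
print((ref - out).abs().max())  # the two kernels should agree closely
out.sum().backward()            # and the backward should compile and run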

yzhangcs avatar May 17 '24 09:05 yzhangcs

I got 6 assertion errors for test_gla:

FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-32-300-4-4] - AssertionError: dg diff: 0.17578125
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-32-512-4-4] - AssertionError: dg diff: 0.2265625
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-64-300-4-4] - AssertionError: dg diff: 0.1630859375
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-64-512-4-4] - AssertionError: dg diff: 0.22265625
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-128-300-4-4] - AssertionError: dg diff: 0.189453125
FAILED tests/ops/test_gla.py::test_fused_chunk[dtype0-128-512-4-4] - AssertionError: dg diff: 0.234375

For the chunk_naive check I got:

tensor(146., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(146., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(197., device='cuda:0', dtype=torch.bfloat16)
tensor(221., device='cuda:0', dtype=torch.bfloat16)
tensor(54.7500, device='cuda:0', dtype=torch.bfloat16)
tensor(161., device='cuda:0', dtype=torch.bfloat16)
tensor(1736., device='cuda:0', dtype=torch.bfloat16)

Thanks for your help.

ostix360 avatar May 18 '24 08:05 ostix360

@ostix360 Hello, could you check it again? I just pushed some commits to fix some issues with initial states.
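
To pick up the fix, you can reinstall from the latest main, e.g. (assuming a git-based install):

$ pip install -U git+https://github.com/sustcsonglin/flash-linear-attention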

yzhangcs avatar May 25 '24 12:05 yzhangcs

Yes, it works fine. Thanks for your help!

ostix360 avatar May 27 '24 09:05 ostix360

My bad, there are NaNs in the model, even though the grad norm is not NaN.

ostix360 avatar May 27 '24 09:05 ostix360

It seems that these NaNs come from the backpropagation.
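
Here is how I'm checking (a small sketch; model and inputs are as in the snippet from my first message):

import torch

# Make autograd report the op that first produces a NaN (slow; debug only).
torch.autograd.set_detect_anomaly(True)

loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()

# List the parameters whose gradients contain NaNs after backward.
for name, param in model.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        print(name, "has NaN grads")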

ostix360 avatar May 27 '24 09:05 ostix360

Hello, can you provide a code snippet that produces the NaNs?

sustcsonglin avatar May 28 '24 07:05 sustcsonglin

Hi, sorry, it was a weight initialization problem. Everything works fine!

ostix360 avatar May 28 '24 09:05 ostix360