
RWKV6 backward issue

Open: ostix360 opened this issue 9 months ago • 9 comments

Hi, I fetched the 3B version of the model from the Hugging Face Hub, and when I call loss.backward() (after model.train()) through the transformers library, I get the following error coming from your library.
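Roughly, the training step looks like this; just a minimal sketch, where the checkpoint id is a placeholder since I only mean "the 3B model from the hub":

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint id: stands in for the 3B RWKV6 model from the hub.
model_id = "fla-hub/rwkv6-3b"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).cuda()
model.train()

batch = tokenizer("a short training sample", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()  # the Triton CompilationError below is raised here
```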

  File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    _u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
                     ^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')
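From the excerpt, it looks like the mask and the loaded tile have transposed shapes: mask_kv is built as [BK, BV], while h and the p_init_s pointer grid put BV on the rows and BK on the columns, i.e. [BV, BK]. That only collides when BK != BV, which matches the [16, 64] vs [64, 16] in the error. A quick shape check mirroring the kernel's names (just a sketch of my reading, not a tested patch):

```python
import torch

BK, BV = 16, 64  # the block sizes implied by the error message

mask_bk = torch.arange(BK) < BK  # stand-in for the kernel's mask_bk
mask_bv = torch.arange(BV) < BV  # stand-in for the kernel's mask_bv

mask_kv = mask_bk[:, None] & mask_bv[None, :]
print(mask_kv.shape)  # torch.Size([16, 64])

# p_init_s puts BK on the column axis and BV on the row axis,
# so the loaded tile (like h) is [BV, BK]:
tile = torch.zeros(BV, BK)
print(tile.shape)  # torch.Size([64, 16])

# [16, 64] cannot broadcast against [64, 16] -> the ValueError above.
# A mask matching the tile layout would be the transpose:
mask_kv_fixed = mask_bv[:, None] & mask_bk[None, :]
print(mask_kv_fixed.shape)  # torch.Size([64, 16])
```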

Thanks for considering this issue.

ostix360 • May 15 '24 17:05