flash-linear-attention
RWKV6 backward issue
Hi, I pulled the 3B version of the model from the Hugging Face Hub, and when I call loss.backward() (after model.train()) through the transformers library, I get a Triton compilation error coming from your library.
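For context, here is a minimal sketch of what I'm running (the model id below is a placeholder, not the exact checkpoint name):

```python
# Minimal repro sketch. "fla-hub/rwkv6-3B" is a placeholder id for the
# 3B RWKV6 checkpoint; substitute the actual Hub repo being loaded.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "fla-hub/rwkv6-3B"  # placeholder, not verified
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).cuda()

model.train()
batch = tokenizer("hello world", return_tensors="pt").to("cuda")
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()  # fails here with the CompilationError below
```

The backward pass then dies with: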
File "/home/ostix/.virtualenvs/AI-architectures/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 61:21:
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
mask_kv = mask_bk[:, None] & mask_bv[None, :]
_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
^
ValueError('Cannot broadcast, the expanded size of the tensor (64) must match the existing size (16) at non-singleton dimension 0: [16, 64], [64, 16]')
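Not sure if this helps, but the shapes in the error look consistent with a mask/pointer layout mismatch: `p_init_s` indexes with `tl.arange(0, BK)[None, :]` and `tl.arange(0, BV)[:, None]`, i.e. a `[BV, BK]` block matching `h = tl.zeros([BV, BK], ...)`, while `mask_kv = mask_bk[:, None] & mask_bv[None, :]` has shape `[BK, BV]`. A sketch of a possible fix inside the kernel, assuming the `[BV, BK]` layout is the intended one:

```python
# Sketch only: build the initial-state mask with axes matching the
# [BV, BK] pointer block used by p_init_s and the accumulator h.
mask_vk = mask_bv[:, None] & mask_bk[None, :]  # [BV, BK] instead of [BK, BV]
h += tl.load(p_init_s, mask=mask_vk, other=0).to(tl.float32)
```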
Thanks for considering this issue.