
After the update, rwkv6's loss becomes NaN

Open JL-er opened this issue 1 year ago • 15 comments

I'm currently using the version from a few days ago, and the loss is normal with that one.

JL-er avatar May 16 '24 04:05 JL-er

Oh, looks like you may need to switch back to logsigmoid; -exp is not stable yet.

yzhangcs avatar May 16 '24 04:05 yzhangcs
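For context, a minimal sketch of the two decay parameterizations being discussed; the tensor shapes and the projection-output name `w` below are illustrative only, not the library's actual code:

```python
import torch
import torch.nn.functional as F

# Hypothetical decay-projection output; shapes are illustrative only.
w = torch.randn(2, 16, 64, requires_grad=True)

# logsigmoid parameterization: always in (-inf, 0) and numerically
# well-behaved, which is what the thread suggests reverting to.
log_decay_stable = F.logsigmoid(w)

# -exp parameterization (RWKV6's native form): large positive w produces
# extremely negative log decays, and the gradient -exp(w) can blow up.
log_decay_native = -torch.exp(w)
```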

[screenshot] This works; the loss is very stable, with essentially no error.

JL-er avatar May 16 '24 06:05 JL-er

[screenshot] It should be a problem introduced by this update.

JL-er avatar May 16 '24 06:05 JL-er

This update fixes potential NaNs during inference, so I don't think it's the cause. It's possibly due to a potentially infinite gradient of -exp; I'll check it, thank you.

yzhangcs avatar May 16 '24 06:05 yzhangcs

RWKV-PEFT has added fla support and it currently works. But as soon as I switch to the new fla, the loss goes NaN. If fla gets further updates, let me know and I can test them.

JL-er avatar May 16 '24 06:05 JL-er

[screenshot] I'm not sure why fla's rwkv6 is actually slower than the CUDA kernel; when I tested gla earlier, it was much faster.

JL-er avatar May 16 '24 06:05 JL-er

Have you compared the kernel speed?

yzhangcs avatar May 16 '24 06:05 yzhangcs
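As a minimal sketch of such a comparison (editor's illustration): a CUDA-event timer that can be pointed at the fla kernel and the reference CUDA kernel with identical inputs; the matmul below is only a stand-in workload.

```python
import torch

def bench(fn, *args, warmup=10, iters=100):
    """Average milliseconds per call of a CUDA callable, via CUDA events."""
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Stand-in workload; in practice, pass the fla RWKV6 kernel and the
# reference CUDA kernel here with the same q/k/v/w/u inputs.
x = torch.randn(1024, 1024, device="cuda")
print(f"{bench(torch.matmul, x, x):.3f} ms/iter")
```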

I'll find time to test that. By the way, one more question: when I do state tuning, swapping in the fla operator raises an error. [screenshot] It's probably because the state's gradient isn't retained, so I'd like to ask how to solve this.

JL-er avatar May 16 '24 07:05 JL-er

You can enable gradients for h0 manually.

yzhangcs avatar May 16 '24 07:05 yzhangcs

Would taking h0 as a learnable param be OK? e.g. h0 = nn.Parameter(torch.zeros(key_dim, head_dim))

yzhangcs avatar May 16 '24 07:05 yzhangcs
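A minimal sketch of that suggestion (editor's illustration); the shapes, the broadcast over the batch, and the `initial_state=` keyword are assumptions about the fla API rather than details confirmed in this thread:

```python
import torch
import torch.nn as nn

# Illustrative shapes only; the exact initial-state layout expected by the
# fla kernels (e.g. an extra batch or head dimension) may differ.
num_heads, key_dim, head_dim = 8, 64, 64

# Register the initial state as a learnable parameter, as suggested above.
h0 = nn.Parameter(torch.zeros(num_heads, key_dim, head_dim))

# At call time, broadcast over the batch and hand it to the kernel as the
# initial state (the kwarg name is an assumption about the fla API):
#   o, final_state = chunk_rwkv6(q, k, v, w, u,
#                                initial_state=h0.expand(batch, -1, -1, -1))
```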

[screenshots] It runs normally when I use the CUDA operator, but not with fla. Normally the gradient of the state computed inside the operator is saved automatically.

JL-er avatar May 16 '24 07:05 JL-er

One more point: here I've frozen all the other weights and only keep the gradient of the state.

JL-er avatar May 16 '24 07:05 JL-er
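A minimal sketch of that setup (editor's illustration): freeze everything except the parameters whose names mark them as state; the keyword used to identify them is a placeholder for however the actual model names its state parameters.

```python
import torch.nn as nn

def freeze_all_but_state(model: nn.Module, state_keyword: str = "h0") -> None:
    """Disable gradients for every parameter except the learnable state(s).

    `state_keyword` is a placeholder; substitute whatever the state
    parameters are actually called in the model being tuned.
    """
    for name, param in model.named_parameters():
        param.requires_grad_(state_keyword in name)
```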

I see, currently there is no access to the grad of the states. We will add an option later.

yzhangcs avatar May 16 '24 10:05 yzhangcs

thank you

JL-er avatar May 16 '24 10:05 JL-er

@JL-er Hi, check it out https://github.com/sustcsonglin/flash-linear-attention/commit/1547448b998a163fdb33c49266da699db13f2dc8

Now we do not truncate the grad of the h states for RWKV6, for ease of state tuning. Do contact us if you meet any bugs or numerical stability issues :-D

yzhangcs avatar May 24 '24 15:05 yzhangcs
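A quick sanity check one might run against the linked commit (editor's sketch, assuming the kernel accepts an `initial_state=` keyword and returns an `(output, final_state)` pair):

```python
import torch

def check_state_grad(kernel, *inputs, state_shape, device="cuda"):
    """Verify that gradients flow back to a learnable initial state."""
    h0 = torch.zeros(*state_shape, device=device, requires_grad=True)
    out, _ = kernel(*inputs, initial_state=h0)
    out.sum().backward()
    assert h0.grad is not None, "no gradient reached the initial state"
    assert torch.isfinite(h0.grad).all(), "initial-state grad has inf/nan"
    return h0.grad
```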

Testing on rwkv-peft works perfectly; clipping is no longer needed. However, infctx training at 6000 ctx len would previously NaN occasionally (I'll re-test that). Thank you very much.

JL-er avatar May 27 '24 03:05 JL-er

FYI, we've recently fixed a bug that caused NaNs when the log decay is very small. https://github.com/fla-org/flash-linear-attention/issues/77#issuecomment-2585539785

sustcsonglin avatar Jan 12 '25 01:01 sustcsonglin