Dev wkv
There is currently a need to migrate RWKV-v4 to libai. Since RWKV-v4 defines a custom CUDA kernel, wkv, that kernel has to be ported as a OneFlow op. This PR ports the v2 version of wkv.
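For context, a rough reference of what the wkv kernel computes, as I read the RWKV-v4 recurrence: each output is an exp-weighted average of past values, with a per-channel decay w and a bonus u for the current position. The hypothetical `wkv_reference` below is a naive O(T²) sketch for intuition only; it uses the same `-exp(w)` sign convention as the script that follows and is not numerically stabilized like the fused kernel.

```python
import numpy as np

def wkv_reference(B, T, C, w, u, k, v):
    # Naive O(T^2) sketch of the wkv recurrence (my reading of RWKV-v4).
    # w, u: (C,); k, v: (B, T, C). Not numerically stable for large T.
    w = -np.exp(w)  # strictly negative per-channel decay, as in the kernel
    y = np.zeros((B, T, C))
    for b in range(B):
        for t in range(T):
            # Current position gets the bonus u instead of the decay.
            num = np.exp(u + k[b, t]) * v[b, t]
            den = np.exp(u + k[b, t])
            for i in range(t):
                # Earlier positions decay by w for each step of distance.
                coef = np.exp((t - 1 - i) * w + k[b, i])
                num += coef * v[b, i]
                den += coef
            y[b, t] = num / den
    return y
```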
Alignment script:
```python
import numpy as np
import oneflow as flow
import torch
from torch.utils.cpp_extension import load

CUDA_KERNEL_VERSION = 2
wkv_cuda = load(
    name="wkv",
    sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"],
    verbose=True,
    extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization'],
)


class WKV(torch.autograd.Function):
    # PyTorch reference path: wraps the original RWKV CUDA kernels.
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        w = -torch.exp(w.contiguous())
        u = u.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw, gu, gk, gv)


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


def CHECK_CUDA():
    B = 1
    T = 1
    C = 3
    # no_grad so in-place uniform_ is allowed on leaf tensors that require grad.
    with torch.no_grad():
        w = torch.zeros(C, requires_grad=True, device='cuda').uniform_(-1, 1)
        u = torch.zeros(C, requires_grad=True, device='cuda').uniform_(-1, 1)
        k = torch.zeros(B, T, C, requires_grad=True, device='cuda').uniform_(-1, 1)
        v = torch.zeros(B, T, C, requires_grad=True, device='cuda').uniform_(-1, 1)
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        y1 = RUN_CUDA(B, T, C, w, u, k, v)
        loss1 = ((y1 * y1) - torch.tanh(y1)).sum()
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        loss1.backward()
    gw = w.grad
    gu = u.grad
    gk = k.grad
    gv = v.grad
    gw_torch = gw.detach().cpu().numpy()
    gu_torch = gu.detach().cpu().numpy()
    gk_torch = gk.detach().cpu().numpy()
    gv_torch = gv.detach().cpu().numpy()
    # OneFlow path: rebuild the same inputs and run the migrated op.
    w = flow.tensor(w.detach().cpu().numpy(), requires_grad=True, device='cuda')
    u = flow.tensor(u.detach().cpu().numpy(), requires_grad=True, device='cuda')
    k = flow.tensor(k.detach().cpu().numpy(), requires_grad=True, device='cuda')
    v = flow.tensor(v.detach().cpu().numpy(), requires_grad=True, device='cuda')
    y2 = flow._C.wkv(B, T, C, w, u, k, v).requires_grad_()
    loss2 = ((y2 * y2) - flow.tanh(y2)).sum()
    loss2.backward()
    gw = w.grad
    gu = u.grad
    gk = k.grad
    gv = v.grad
    gw_flow = gw.detach().cpu().numpy()
    gu_flow = gu.detach().cpu().numpy()
    gk_flow = gk.detach().cpu().numpy()
    gv_flow = gv.detach().cpu().numpy()
    print(np.allclose(gw_flow, gw_torch, atol=1e-5))
    print(np.allclose(gu_flow, gu_torch, atol=1e-5))
    print(np.allclose(gk_flow, gk_torch, atol=1e-5))
    print(np.allclose(gv_flow, gv_torch, atol=1e-5))
    print(gu_flow, gu_torch)


if __name__ == "__main__":
    CHECK_CUDA()
```
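If the migrated op is aligned, the four allclose checks all print True; the final print of gu_flow and gu_torch is just a spot check of the raw gradient values.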
I'd prefer not to merge this into the master branch; instead, users can compile and use this branch directly.
Original requirement:
```python
class WKV(torch.autograd.Function):
    # T_MAX and wkv_cuda are defined elsewhere in the RWKV-v4 codebase.
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
```
This class needs to be turned into an op that OneFlow can use. The two ops implemented in this PR actually correspond to the forward and backward passes of a single op, so one more change is needed: register the wkv_backward op as wkv_forward_grad.
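Once that rename lands, the model-side wrapper on the OneFlow side should reduce to a plain function call; a minimal sketch, assuming the op remains exposed as flow._C.wkv as in the alignment script above:

```python
import oneflow as flow

def RUN_CUDA(B, T, C, w, u, k, v):
    # No custom autograd.Function is needed on the OneFlow side: with
    # wkv_backward registered as wkv_forward_grad, autograd should pick up
    # the backward kernel automatically.
    return flow._C.wkv(B, T, C, w.to("cuda"), u.to("cuda"), k.to("cuda"), v.to("cuda"))
```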
@hjchen2 @guo-ran I don't think our bf16 ops can share a single list with fp16. The problem I ran into: if layernorm computes in fp16 the loss aligns, but if it computes in bf16 the loss suddenly blows up.
So fp16 and bf16 should be maintained as two separate lists; a model that trains fine in fp16 is not necessarily suited to bf16. My current workaround is to comment layernorm out of the gray list to keep bf16 mode correct.
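A minimal sketch of the idea (the names here are hypothetical illustrations, not OneFlow's actual AMP list structures): key the lists by AMP dtype instead of sharing one set, so layernorm can stay in the fp16 gray list while being dropped from the bf16 one.

```python
# Hypothetical illustration only; not OneFlow's real AMP list API.
AMP_GRAY_LIST = {
    "float16": {"layer_norm", "add_n"},
    "bfloat16": {"add_n"},  # layer_norm removed: bf16 loss diverged in testing
}

def cast_to_low_precision(op_type: str, amp_dtype: str) -> bool:
    # An op is cast down only if it is gray-listed for that specific dtype.
    return op_type in AMP_GRAY_LIST.get(amp_dtype, set())
```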