
Dev wkv

Open zhongshsh opened this issue 1 year ago • 3 comments

There is currently a need to port RWKV-v4 to libai. Since RWKV-v4 defines a custom CUDA operator, wkv, that operator needs to be migrated to a OneFlow operator. This PR ports the v2 version of wkv:

  • wkv.forward, which computes the value of y -> flow._C.wkv
  • wkv.backward, which computes the parameter gradients -> the gradient operator of flow._C.wkv

Alignment script:

import numpy as np
import oneflow as flow
import torch
from torch.utils.cpp_extension import load

CUDA_KERNEL_VERSION = 2
# Build the reference wkv CUDA kernel from the RWKV sources.
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", f"cuda/wkv_cuda_v{CUDA_KERNEL_VERSION}.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization'])

# Reference path: wrap the RWKV CUDA kernel as a PyTorch autograd Function.
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        w = -torch.exp(w.contiguous())
        u = u.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw, gu, gk, gv)

def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

# Run identical inputs through both implementations and compare the resulting gradients.
def CHECK_CUDA():
    B = 1
    T = 1
    C = 3

    with torch.no_grad():
        w = torch.zeros(C, requires_grad=True, device='cuda').uniform_(-1, 1)
        u = torch.zeros(C, requires_grad=True, device='cuda').uniform_(-1, 1)
        k = torch.zeros(B, T, C, requires_grad=True, device='cuda').uniform_(-1, 1)
        v = torch.zeros(B, T, C, requires_grad=True, device='cuda').uniform_(-1, 1)

    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        y1 = RUN_CUDA(B, T, C, w, u, k, v)
    loss1 = ((y1 * y1) - torch.tanh(y1)).sum()
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        loss1.backward()
    
    # Reference gradients from the PyTorch/CUDA path.
    gw = w.grad
    gu = u.grad
    gk = k.grad
    gv = v.grad
    gw_torch = gw.detach().cpu().numpy()
    gu_torch = gu.detach().cpu().numpy()
    gk_torch = gk.detach().cpu().numpy()
    gv_torch = gv.detach().cpu().numpy()
    
    # Rebuild the same inputs as OneFlow tensors and run them through flow._C.wkv.
    w = flow.tensor(w.detach().cpu().numpy(), requires_grad=True, device='cuda')
    u = flow.tensor(u.detach().cpu().numpy(), requires_grad=True, device='cuda')
    k = flow.tensor(k.detach().cpu().numpy(), requires_grad=True, device='cuda')
    v = flow.tensor(v.detach().cpu().numpy(), requires_grad=True, device='cuda')
    y2 = flow._C.wkv(B, T, C, w, u, k, v).requires_grad_()
    loss2 = ((y2 * y2) - flow.tanh(y2)).sum()
    loss2.backward()

    # Gradients from the OneFlow operator.
    gw = w.grad
    gu = u.grad
    gk = k.grad
    gv = v.grad
    gw_flow = gw.detach().cpu().numpy()
    gu_flow = gu.detach().cpu().numpy()
    gk_flow = gk.detach().cpu().numpy()
    gv_flow = gv.detach().cpu().numpy()

    print(np.allclose(gw_flow, gw_torch, atol=1e-5))
    print(np.allclose(gu_flow, gu_torch, atol=1e-5))
    print(np.allclose(gk_flow, gk_torch, atol=1e-5))
    print(np.allclose(gv_flow, gv_torch, atol=1e-5))
    print(gu_flow, gu_torch)

    
if __name__ == "__main__":
    CHECK_CUDA()

zhongshsh · Jul 26 '22 09:07

I'd prefer not to merge this into the master branch; instead, keep this branch around for users to build and use.

MARD1NO · Jul 28 '22 06:07

The original requirement:

# Original fp16 WKV class from the RWKV codebase (T_MAX and wkv_cuda are defined elsewhere there).
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w = -torch.exp(w.float().contiguous())
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return y.half()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())

This class needs to become an operator that can be used from OneFlow. In practice, the two operators implemented in this PR correspond to the forward and backward passes of a single op, so one more change is needed: register the wkv_backward op as wkv_forward_grad.
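For reference, a minimal sketch of how the original half-precision wrapper might look on top of the new op, assuming flow._C.wkv applies the -exp(w) transform internally (as the alignment script above suggests) and already has its backward registered. T_MAX here is a placeholder mirroring the original RWKV code, not part of this PR:

import oneflow as flow

T_MAX = 1024  # hypothetical sequence-length limit, mirroring the original RWKV code

def run_wkv_fp16(B, T, C, w, u, k, v):
    # Same shape constraints as the original class.
    assert T <= T_MAX
    assert B * C % min(C, 1024) == 0
    # The CUDA kernel computes in fp32: cast fp16 inputs up, cast the result back down.
    y = flow._C.wkv(B, T, C,
                    w.float().contiguous(),
                    u.float().contiguous(),
                    k.float().contiguous(),
                    v.float().contiguous())
    return y.half()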

BBuf · Jul 28 '22 06:07

@hjchen2 @guo-ran I don't think our bf16 support can be maintained in the same list as fp16. One problem I ran into: if layernorm is computed in fp16 the loss aligns, but if it is computed in bf16 the loss suddenly jumps, like this:

[image: training loss curve showing the sudden jump under bf16]

So fp16 and bf16 should probably be maintained as two separate lists; a model that trains fine in fp16 is not necessarily suitable for bf16. For now, my workaround is to comment out layernorm in the gray list to keep bf16 mode correct.
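As a side note (a hedged illustration, not from this thread): bf16 carries only 8 significand bits versus fp16's 11, so the reductions inside layernorm lose noticeably more precision in bf16. A quick PyTorch comparison against an fp32 reference:

import torch

# Compare layernorm run in fp16 and bf16 against an fp32 reference (CUDA assumed).
x = (torch.randn(1024, 768) * 5 + 3).cuda()
ref = torch.nn.functional.layer_norm(x, (768,))

for dtype in (torch.float16, torch.bfloat16):
    y = torch.nn.functional.layer_norm(x.to(dtype), (768,)).float()
    print(dtype, (y - ref).abs().max().item())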

BBuf · Aug 18 '22 08:08