ring-flash-attention icon indicating copy to clipboard operation
ring-flash-attention copied to clipboard

ring flash attention with BPT

Open JiaoPL opened this issue 2 years ago • 3 comments

Hi~ @zhuzilin 我正在尝试将BPT 接入ring flash attention,使用chunk_size切分qkv,在local进行更小chunk的attention计算。 参照ring_flash_attn.py的forward和backward,实现了 blockwise_flash_attn_forwardblockwise_flash_attn_backward,目前forward精度可以对齐,backward存在误差。我想问一下,backward的实现可能存在哪些问题? 下面是我的实现:

def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True
):
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0
    
    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    batch,seqlen,head_dim,num_head = q.shape
    
    block_out = torch.empty(q.shape, dtype=torch.float32, device=q.device)
    block_lse = torch.empty((batch,seqlen,head_dim,1), dtype=torch.float32, device=q.device)

    for i in range(num_q_chunk):
        q_i = q[:,i * q_chunk_size: (i + 1) * q_chunk_size]
        out_i = None
        lse_i = None
        
        for j in range(num_k_chunk-1,-1,-1):
            if j > i and causal:
                continue
            
            k_j = k[:,j * k_chunk_size: (j + 1) * k_chunk_size]
            v_j = v[:,j * k_chunk_size: (j + 1) * k_chunk_size]
            
            out_ij, _, _, _, _, lse_ij, _, _ = _flash_attn_forward(
                q_i,
                k_j,
                v_j,
                dropout_p,
                softmax_scale,
                causal=causal and j == i,
                return_softmax=return_softmax and dropout_p > 0
            )
            out_i, lse_i = update_out_and_lse(out_i, lse_i, out_ij, lse_ij)

        block_out[:, i * q_chunk_size: (i + 1) * q_chunk_size] = out_i
        block_lse[:, i * q_chunk_size: (i + 1) * q_chunk_size] = lse_i
        
    return block_out, block_lse.squeeze(dim=-1).transpose(-1,-2)


def blockwise_flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    q_chunk_size,
    k_chunk_size,
    softmax_lse,
    dq,
    dk,
    dv,
    softmax_scale,
    dropout_p,
    causal=True,
    rng_state=None
):
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0

    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size

    temp_dq_buffer = torch.empty(q[:,:q_chunk_size].shape, dtype=q.dtype, device=q.device)
    temp_dk_buffer = torch.empty(k[:,:k_chunk_size].shape, dtype=k.dtype, device=k.device)
    temp_dv_buffer = torch.empty(v[:,:k_chunk_size].shape, dtype=v.dtype, device=v.device)
    
    
    for i in range(num_q_chunk):
        q_i = q[:,i * q_chunk_size: (i + 1) * q_chunk_size]
        dout_i = dout[:,i * q_chunk_size: (i + 1) * q_chunk_size]
        out_i = out[:,i * q_chunk_size: (i + 1) * q_chunk_size]
        softmax_lse_i = softmax_lse[:,:,i * q_chunk_size: (i + 1) * q_chunk_size]
        q_i = q_i.contiguous()
        dout_i = dout_i.contiguous()
        out_i = out_i.contiguous()
        softmax_lse_i = softmax_lse_i.contiguous()

        for j in range(num_k_chunk):
            k_j = k[:,j * k_chunk_size: (j + 1) * k_chunk_size]
            v_j = v[:,j * k_chunk_size: (j + 1) * k_chunk_size]
            k_j = k_j.contiguous()
            v_j = v_j.contiguous()

            if j > i and causal:
                continue

            _flash_attn_backward(
                dout_i,
                q_i,
                k_j,
                v_j,
                out_i,
                softmax_lse_i,
                temp_dq_buffer,
                temp_dk_buffer,
                temp_dv_buffer,
                dropout_p,
                softmax_scale,
                causal = causal and j == i,
                rng_state=rng_state,
            )
            
            # update dq dk dv
            dq[:,i * q_chunk_size: (i + 1) * q_chunk_size] += temp_dq_buffer
            dk[:,j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dk_buffer
            dv[:,j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dv_buffer

分别替换ring_flash_attn_forward 中的_flash_attn_forward,和ring_flash_attn_backward中的_flash_attn_backward

下面是我的测试结果:

##############################
# forward:
##############################
out: max 2.896484375, mean 0.0203094482421875
lse: max 10.417832374572754, mean 9.204237937927246
out diff:
[0] max 0.00048828125, mean 8.881092071533203e-06
[1] max 0.0001220703125, mean 7.450580596923828e-06
[2] max 0.0001220703125, mean 5.9604644775390625e-06
[3] max 6.103515625e-05, mean 5.066394805908203e-06
[4] max 6.103515625e-05, mean 4.5299530029296875e-06
[5] max 6.103515625e-05, mean 4.112720489501953e-06
[6] max 6.103515625e-05, mean 3.814697265625e-06
[7] max 6.103515625e-05, mean 3.516674041748047e-06
lse diff:
[0] max 9.5367431640625e-07, mean 1.645181413323371e-07
[1] max 9.5367431640625e-07, mean 2.641230878452916e-07
[2] max 1.9073486328125e-06, mean 3.0044466825529526e-07
[3] max 1.9073486328125e-06, mean 3.3890827921823075e-07
[4] max 1.9073486328125e-06, mean 3.8137659430503845e-07
[5] max 1.9073486328125e-06, mean 4.0913002408160537e-07
[6] max 1.9073486328125e-06, mean 4.272908142866072e-07
[7] max 1.9073486328125e-06, mean 4.6798959374427795e-07
##############################
# backward:
##############################
load_dq:
[0] max 2.783203125, mean 0.052520751953125
[1] max 0.3310546875, mean 0.02398681640625
[2] max 0.2083740234375, mean 0.0184478759765625
[3] max 0.1162109375, mean 0.0155792236328125
[4] max 0.13330078125, mean 0.01374053955078125
[5] max 0.1204833984375, mean 0.01241302490234375
[6] max 0.11260986328125, mean 0.0114288330078125
[7] max 0.0775146484375, mean 0.01064300537109375
dq diff:
[0] max 0.005859375, mean 7.49826431274414e-05
[1] max 0.186279296875, mean 0.01239776611328125
[2] max 0.1973876953125, mean 0.01953125
[3] max 0.235107421875, mean 0.0253143310546875
[4] max 0.30615234375, mean 0.0301361083984375
[5] max 0.52392578125, mean 0.03436279296875
[6] max 0.56689453125, mean 0.038177490234375
[7] max 0.3955078125, mean 0.041748046875
load_dk:
[0] max 2.654296875, mean 0.05340576171875
[1] max 0.256591796875, mean 0.021697998046875
[2] max 0.169921875, mean 0.01535797119140625
[3] max 0.13330078125, mean 0.0116729736328125
[4] max 0.09124755859375, mean 0.0090484619140625
[5] max 0.1158447265625, mean 0.006908416748046875
[6] max 0.050384521484375, mean 0.00492095947265625
[7] max 0.03936767578125, mean 0.002498626708984375
dk diff:
[0] max 0.253173828125, mean 0.03192138671875
[1] max 0.16845703125, mean 0.0232696533203125
[2] max 0.130126953125, mean 0.017364501953125
[3] max 0.1097412109375, mean 0.012786865234375
[4] max 0.10797119140625, mean 0.00893402099609375
[5] max 0.049530029296875, mean 0.005580902099609375
[6] max 0.039337158203125, mean 0.002498626708984375
[7] max 1.52587890625e-05, mean 3.5762786865234375e-07
load_dv:
[0] max 5.89453125, mean 0.05450439453125
[1] max 0.1951904296875, mean 0.021484375
[2] max 0.11883544921875, mean 0.01525115966796875
[3] max 0.10003662109375, mean 0.01158905029296875
[4] max 0.07550048828125, mean 0.00901031494140625
[5] max 0.06658935546875, mean 0.006816864013671875
[6] max 0.041015625, mean 0.00492095947265625
[7] max 0.041961669921875, mean 0.002475738525390625
dv diff:
[0] max 0.3232421875, mean 0.042572021484375
[1] max 0.21240234375, mean 0.03094482421875
[2] max 0.1527099609375, mean 0.0223236083984375
[3] max 0.1075439453125, mean 0.015625
[4] max 0.08245849609375, mean 0.010223388671875
[5] max 0.0447998046875, mean 0.005950927734375
[6] max 0.0419921875, mean 0.002475738525390625
[7] max 3.0517578125e-05, mean 3.5762786865234375e-07

JiaoPL avatar Mar 22 '24 09:03 JiaoPL

Do we even still need the BPT if we have the ring attention implemented in this repo? @zhuzilin

I personally think BPT is a single-GPU version of ring attention, right?

GeneZC avatar Mar 24 '24 04:03 GeneZC

Do we even still need the BPT if we have the ring attention implemented in this repo? @zhuzilin

I personally think BPT is a single-GPU version of ring attention, right?

That's right, BPT is inherently supported by ring attention. We do not need another implementation. image

Edenzzzz avatar Mar 28 '24 05:03 Edenzzzz

I'm not sure spliting the sequence length on each device into blocks could save memory (because we still need save buffers and flash_attn itself seems to use linear size memory w.r.t. sequence length), or speed up (because it will call smaller kernels).

zhuzilin avatar Apr 18 '24 08:04 zhuzilin