ring flash attention with BPT
Hi~ @zhuzilin
I'm trying to integrate BPT into ring flash attention: q/k/v are split by a chunk_size, and the attention is computed locally over these smaller chunks.
Following the forward and backward in ring_flash_attn.py, I implemented blockwise_flash_attn_forward and blockwise_flash_attn_backward. The forward matches the reference in precision, but the backward has errors. What problems might exist in my backward implementation?
Here is my implementation:
```python
def blockwise_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_chunk_size: int,
    k_chunk_size: int,
    softmax_scale,
    dropout_p=0,
    causal=True,
    return_softmax=True,
):
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0
    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    batch, seqlen, num_head, head_dim = q.shape
    block_out = torch.empty(q.shape, dtype=torch.float32, device=q.device)
    block_lse = torch.empty((batch, seqlen, num_head, 1), dtype=torch.float32, device=q.device)
    for i in range(num_q_chunk):
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size]
        out_i = None
        lse_i = None
        for j in range(num_k_chunk - 1, -1, -1):
            # skip fully masked KV chunks (assumes q_chunk_size == k_chunk_size)
            if j > i and causal:
                continue
            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size]
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size]
            out_ij, _, _, _, _, lse_ij, _, _ = _flash_attn_forward(
                q_i,
                k_j,
                v_j,
                dropout_p,
                softmax_scale,
                causal=causal and j == i,
                return_softmax=return_softmax and dropout_p > 0,
            )
            # merge this KV chunk's partial output into the running output/LSE
            out_i, lse_i = update_out_and_lse(out_i, lse_i, out_ij, lse_ij)
        block_out[:, i * q_chunk_size: (i + 1) * q_chunk_size] = out_i
        block_lse[:, i * q_chunk_size: (i + 1) * q_chunk_size] = lse_i
    # return lse as (batch, num_head, seqlen) to match flash_attn's softmax_lse layout
    return block_out, block_lse.squeeze(dim=-1).transpose(-1, -2)
```
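For context, the merge I assume `update_out_and_lse` performs is the standard log-sum-exp combination of two partial attention results. A minimal numerical sketch (not the repo's exact implementation; the helper name is mine):

```python
import torch

def merge_out_and_lse(out1, lse1, out2, lse2):
    # out*: (batch, seqlen, nheads, headdim) partial outputs in fp32
    # lse*: (batch, seqlen, nheads, 1) log-sum-exp of the corresponding score blocks
    lse = torch.logaddexp(lse1, lse2)  # combined normalizer: log(exp(lse1) + exp(lse2))
    # reweight each partial output by its share of the combined softmax mass
    out = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2
    return out, lse
```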
```python
def blockwise_flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    q_chunk_size,
    k_chunk_size,
    softmax_lse,
    dq,
    dk,
    dv,
    softmax_scale,
    dropout_p,
    causal=True,
    rng_state=None,
):
    assert q.shape[1] % q_chunk_size == 0
    assert k.shape[1] % k_chunk_size == 0
    num_q_chunk = q.shape[1] // q_chunk_size
    num_k_chunk = k.shape[1] // k_chunk_size
    # reusable per-chunk gradient buffers written by _flash_attn_backward
    temp_dq_buffer = torch.empty(q[:, :q_chunk_size].shape, dtype=q.dtype, device=q.device)
    temp_dk_buffer = torch.empty(k[:, :k_chunk_size].shape, dtype=k.dtype, device=k.device)
    temp_dv_buffer = torch.empty(v[:, :k_chunk_size].shape, dtype=v.dtype, device=v.device)
    for i in range(num_q_chunk):
        q_i = q[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        dout_i = dout[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        out_i = out[:, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        softmax_lse_i = softmax_lse[:, :, i * q_chunk_size: (i + 1) * q_chunk_size].contiguous()
        for j in range(num_k_chunk):
            # skip fully masked KV chunks (assumes q_chunk_size == k_chunk_size)
            if j > i and causal:
                continue
            k_j = k[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()
            v_j = v[:, j * k_chunk_size: (j + 1) * k_chunk_size].contiguous()
            _flash_attn_backward(
                dout_i,
                q_i,
                k_j,
                v_j,
                out_i,
                softmax_lse_i,
                temp_dq_buffer,
                temp_dk_buffer,
                temp_dv_buffer,
                dropout_p,
                softmax_scale,
                causal=causal and j == i,
                rng_state=rng_state,
            )
            # accumulate the per-chunk gradients into dq, dk, dv
            dq[:, i * q_chunk_size: (i + 1) * q_chunk_size] += temp_dq_buffer
            dk[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dk_buffer
            dv[:, j * k_chunk_size: (j + 1) * k_chunk_size] += temp_dv_buffer
```
These two functions replace _flash_attn_forward inside ring_flash_attn_forward and _flash_attn_backward inside ring_flash_attn_backward, respectively.
Below are my test results:
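To narrow down where the backward diverges, I also compare against a naive fp32 attention computed with plain PyTorch and autograd. A sketch (the helper `naive_attention_reference` is a name I made up; the mask matches flash_attn's causal convention for equal q/k lengths):

```python
import torch

def naive_attention_reference(q, k, v, softmax_scale, causal=True):
    # q, k, v: (batch, seqlen, nheads, headdim); everything in fp32 for a tight reference
    q32, k32, v32 = q.float(), k.float(), v.float()
    scores = torch.einsum("bqhd,bkhd->bhqk", q32, k32) * softmax_scale
    if causal:
        mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v32)

# Usage: enable requires_grad on q/k/v, backward the reference output against the same
# dout, then compare q.grad / k.grad / v.grad with the dq / dk / dv produced above.
```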
```
##############################
# forward:
##############################
out: max 2.896484375, mean 0.0203094482421875
lse: max 10.417832374572754, mean 9.204237937927246
out diff:
[0] max 0.00048828125, mean 8.881092071533203e-06
[1] max 0.0001220703125, mean 7.450580596923828e-06
[2] max 0.0001220703125, mean 5.9604644775390625e-06
[3] max 6.103515625e-05, mean 5.066394805908203e-06
[4] max 6.103515625e-05, mean 4.5299530029296875e-06
[5] max 6.103515625e-05, mean 4.112720489501953e-06
[6] max 6.103515625e-05, mean 3.814697265625e-06
[7] max 6.103515625e-05, mean 3.516674041748047e-06
lse diff:
[0] max 9.5367431640625e-07, mean 1.645181413323371e-07
[1] max 9.5367431640625e-07, mean 2.641230878452916e-07
[2] max 1.9073486328125e-06, mean 3.0044466825529526e-07
[3] max 1.9073486328125e-06, mean 3.3890827921823075e-07
[4] max 1.9073486328125e-06, mean 3.8137659430503845e-07
[5] max 1.9073486328125e-06, mean 4.0913002408160537e-07
[6] max 1.9073486328125e-06, mean 4.272908142866072e-07
[7] max 1.9073486328125e-06, mean 4.6798959374427795e-07
##############################
# backward:
##############################
load_dq:
[0] max 2.783203125, mean 0.052520751953125
[1] max 0.3310546875, mean 0.02398681640625
[2] max 0.2083740234375, mean 0.0184478759765625
[3] max 0.1162109375, mean 0.0155792236328125
[4] max 0.13330078125, mean 0.01374053955078125
[5] max 0.1204833984375, mean 0.01241302490234375
[6] max 0.11260986328125, mean 0.0114288330078125
[7] max 0.0775146484375, mean 0.01064300537109375
dq diff:
[0] max 0.005859375, mean 7.49826431274414e-05
[1] max 0.186279296875, mean 0.01239776611328125
[2] max 0.1973876953125, mean 0.01953125
[3] max 0.235107421875, mean 0.0253143310546875
[4] max 0.30615234375, mean 0.0301361083984375
[5] max 0.52392578125, mean 0.03436279296875
[6] max 0.56689453125, mean 0.038177490234375
[7] max 0.3955078125, mean 0.041748046875
load_dk:
[0] max 2.654296875, mean 0.05340576171875
[1] max 0.256591796875, mean 0.021697998046875
[2] max 0.169921875, mean 0.01535797119140625
[3] max 0.13330078125, mean 0.0116729736328125
[4] max 0.09124755859375, mean 0.0090484619140625
[5] max 0.1158447265625, mean 0.006908416748046875
[6] max 0.050384521484375, mean 0.00492095947265625
[7] max 0.03936767578125, mean 0.002498626708984375
dk diff:
[0] max 0.253173828125, mean 0.03192138671875
[1] max 0.16845703125, mean 0.0232696533203125
[2] max 0.130126953125, mean 0.017364501953125
[3] max 0.1097412109375, mean 0.012786865234375
[4] max 0.10797119140625, mean 0.00893402099609375
[5] max 0.049530029296875, mean 0.005580902099609375
[6] max 0.039337158203125, mean 0.002498626708984375
[7] max 1.52587890625e-05, mean 3.5762786865234375e-07
load_dv:
[0] max 5.89453125, mean 0.05450439453125
[1] max 0.1951904296875, mean 0.021484375
[2] max 0.11883544921875, mean 0.01525115966796875
[3] max 0.10003662109375, mean 0.01158905029296875
[4] max 0.07550048828125, mean 0.00901031494140625
[5] max 0.06658935546875, mean 0.006816864013671875
[6] max 0.041015625, mean 0.00492095947265625
[7] max 0.041961669921875, mean 0.002475738525390625
dv diff:
[0] max 0.3232421875, mean 0.042572021484375
[1] max 0.21240234375, mean 0.03094482421875
[2] max 0.1527099609375, mean 0.0223236083984375
[3] max 0.1075439453125, mean 0.015625
[4] max 0.08245849609375, mean 0.010223388671875
[5] max 0.0447998046875, mean 0.005950927734375
[6] max 0.0419921875, mean 0.002475738525390625
[7] max 3.0517578125e-05, mean 3.5762786865234375e-07
```
Do we even still need the BPT if we have the ring attention implemented in this repo? @zhuzilin
I personally think BPT is a single-GPU version of ring attention, right?
That's right, BPT is inherently supported by ring attention. We do not need another implementation.
I'm not sure splitting the sequence on each device into blocks would save memory (we still need to keep the buffers, and flash_attn itself already seems to use memory linear in the sequence length), or give a speedup (it would just launch smaller kernels).
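For anyone who wants to check the linear-memory point, peak allocation can be measured directly. A rough sketch, assuming the public flash_attn_func interface (tensor sizes are arbitrary):

```python
import torch
from flash_attn import flash_attn_func

def peak_mem_mib(seqlen, batch=1, nheads=8, headdim=64):
    # measure peak CUDA memory for one flash attention call at a given sequence length
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    q, k, v = (
        torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
        for _ in range(3)
    )
    flash_attn_func(q, k, v, causal=True)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# if memory is roughly linear in seqlen, doubling the length should roughly double the peak
print(peak_mem_mib(4096), peak_mem_mib(8192))
```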