Backward Kernel Implementation
Hi,
I have been trying to make some progress on the backward kernel for training. Unfortunately, I am new to GPU programming and Triton, so I may be missing some parts. I would appreciate any advice or information you could provide on this implementation:
@triton.jit
def _bwd_preprocess(
Out, DO, D,
o_batch_stride, o_head_stride, o_m_stride,
d_o_batch_stride, d_o_head_stride, d_o_m_stride,
heads,
BLOCK_M: tl.constexpr, BLOCK_DHEAD: tl.constexpr,
):
m_block_idx = tl.program_id(0)
head_idx = tl.program_id(1)
current_batch_idx = head_idx // heads
current_head_idx = head_idx % heads
range_offs_m = tl.arange(0, BLOCK_M)
off_m = m_block_idx * BLOCK_M + range_offs_m
off_n = tl.arange(0, BLOCK_DHEAD)
# load
    o = tl.load(Out + current_batch_idx * o_batch_stride + current_head_idx * o_head_stride
                + off_m[:, None] * o_m_stride + off_n[None, :]).to(tl.float32)
    d_o = tl.load(DO + current_batch_idx * d_o_batch_stride + current_head_idx * d_o_head_stride
                  + off_m[:, None] * d_o_m_stride + off_n[None, :]).to(tl.float32)
# compute
delta = tl.sum(o * d_o, axis=1)
# write-back
tl.store(D + head_idx + off_m, delta)
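# For reference, the quantity _bwd_preprocess computes per row is simply the dot
# product of the forward output with its incoming gradient, i.e.
# delta[b, h, m] = sum_d Out[b, h, m, d] * DO[b, h, m, d]. Assuming tensors of shape
# [batch, heads, seq_m, d_head], the plain-PyTorch equivalent would be something like
# (tensor names here are illustrative only):
#
#     delta = (out.float() * d_out.float()).sum(dim=-1)   # [batch, heads, seq_m]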
@triton.jit
def _bwd_kernel(
heads,
size_m,
size_n,
size_m_cache_key,
size_n_cache_key,
Q, K, V,
DO, DQ, DK, DV,
D,
sm_scale,
attention_mask,
q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
k_batch_stride, k_head_stride, k_n_stride, k_k_stride, # axis named n,k instead of k,n because of the transpose of K matrix
v_batch_stride, v_head_stride, v_k_stride, v_n_stride,
d_o_batch_stride, d_o_head_stride, d_o_m_stride, d_o_n_stride,
d_q_batch_stride, d_q_head_stride, d_q_m_stride, d_q_k_stride,
attention_mask_batch_stride, attention_mask_head_stride, attention_mask_m_stride, attention_mask_n_stride,
min_clamp_value,
attention_mask_batch_size, attention_mask_head_size, attention_mask_m_size, attention_mask_n_size,
HAS_MASK: tl.constexpr,
IS_MATRIX_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_DHEAD: tl.constexpr,
BLOCK_M: tl.constexpr, # this parameter and below are managed by the autotune and need to be at the end
BLOCK_N: tl.constexpr,
NEED_LOAD_MASK_SIZE_M: tl.constexpr,
NEED_LOAD_MASK_SIZE_N: tl.constexpr,
):
m_block_idx = tl.program_id(0)
head_idx = tl.program_id(1)
range_offs_m = tl.arange(0, BLOCK_M)
range_offs_n = tl.arange(0, BLOCK_N)
range_offs_d = tl.arange(0, BLOCK_DHEAD)
    offs_m = m_block_idx * BLOCK_M + range_offs_m
current_batch_idx = head_idx // heads
current_head_idx = head_idx % heads
offs_q = (
current_batch_idx * q_batch_stride
+ current_head_idx * q_head_stride
+ (offs_m[:, None] * q_m_stride + range_offs_d[None, :] * q_k_stride)
)
offs_k = (
current_batch_idx * k_batch_stride
+ current_head_idx * k_head_stride
+ (range_offs_n[:, None] * k_n_stride + range_offs_d[None, :] * k_k_stride)
)
offs_v = (
current_batch_idx * v_batch_stride
+ current_head_idx * v_head_stride
        + (range_offs_n[:, None] * v_k_stride + range_offs_d[None, :] * v_n_stride)
)
offs_d_o = (
current_batch_idx * d_o_batch_stride
+ current_head_idx * d_o_head_stride
+ (offs_m[:, None] * d_o_m_stride + range_offs_n[None, :] * d_o_n_stride)
)
offs_d_q = (
current_batch_idx * d_q_batch_stride
+ current_head_idx * d_q_head_stride
+ (offs_m[:, None] * d_q_m_stride + range_offs_d[None, :] * d_q_k_stride)
)
ptrs_q = Q + offs_q
ptrs_k = K + offs_k
ptrs_v = V + offs_v
ptrs_d_o = DO + offs_d_o
ptrs_d_q = DQ + offs_d_q
    # dv and dk accumulate (BLOCK_N, BLOCK_DHEAD) tiles, matching the tl.dot results below
    dv = tl.zeros((BLOCK_N, BLOCK_DHEAD), dtype=tl.float32)
    dk = tl.zeros((BLOCK_N, BLOCK_DHEAD), dtype=tl.float32)
if NEED_LOAD_MASK_SIZE_M | NEED_LOAD_MASK_SIZE_N:
q = tl.load(ptrs_q, mask=offs_m[:, None] < size_m, other=0.0)
else:
q = tl.load(ptrs_q)
n_end = size_n
if IS_CAUSAL:
        n_end = (m_block_idx + 1) * BLOCK_M
if HAS_MASK:
mask_batch_idx = (current_batch_idx,)
if attention_mask_batch_size == 1:
mask_batch_idx = 0
mask_head_idx = current_head_idx
if attention_mask_head_size == 1:
mask_head_idx = 0
offs_base_mask = mask_batch_idx * attention_mask_batch_stride + mask_head_idx * attention_mask_head_stride
    for block_start_index_n in range(0, n_end, BLOCK_N):
block_start_index_n = tl.multiple_of(block_start_index_n, BLOCK_N)
offs_n = block_start_index_n + range_offs_n
        if NEED_LOAD_MASK_SIZE_N:
            k = tl.load(ptrs_k + block_start_index_n * k_n_stride, mask=offs_n[:, None] < size_n, other=0.0)
        else:
            k = tl.load(ptrs_k + block_start_index_n * k_n_stride)
qk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
if NEED_LOAD_MASK_SIZE_N:
            qk = tl.where(offs_n[None, :] < size_n, qk, float("-inf"))
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
if HAS_MASK:
# we assume mask has a vector shape
offs_mask = offs_base_mask + offs_n[None, :] * attention_mask_n_stride
if IS_MATRIX_MASK: # mask has a matrix shape, we load (BLOCK_M, BLOCK_N) elements
offs_mask += offs_m[:, None] * attention_mask_m_stride
if NEED_LOAD_MASK_SIZE_N & (not IS_MATRIX_MASK): # mask has a vector shape + need a load mask
attention_load_mask = offs_n[None, :] < attention_mask_n_size
if IS_MATRIX_MASK: # mask has a matrix shape
if NEED_LOAD_MASK_SIZE_M & (not NEED_LOAD_MASK_SIZE_N): # load mask on M axis
attention_load_mask = offs_m[:, None] < attention_mask_m_size
elif (not NEED_LOAD_MASK_SIZE_M) & NEED_LOAD_MASK_SIZE_N: # load mask on N axis
attention_load_mask = offs_n[None, :] < attention_mask_n_size
elif NEED_LOAD_MASK_SIZE_M & NEED_LOAD_MASK_SIZE_N: # load mask on both axis
attention_load_mask = (offs_n[None, :] < attention_mask_n_size) & (
offs_m[:, None] < attention_mask_m_size
)
if (NEED_LOAD_MASK_SIZE_M & IS_MATRIX_MASK) | NEED_LOAD_MASK_SIZE_N:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first",
mask=attention_load_mask,
other=float("-inf"),
)
else:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first",
)
# Avoids NaN
m = tl.where(m == float("-inf"), min_clamp_value, m)
qk += m
p = tl.exp(qk * sm_scale - m[:, None])
if NEED_LOAD_MASK_SIZE_M:
d_o = tl.load(ptrs_d_o, mask=offs_m[:, None] < size_m, other=0.0)
else:
d_o = tl.load(ptrs_d_o)
dv += tl.dot(p.to(d_o.dtype), d_o, trans_a=True)
if NEED_LOAD_MASK_SIZE_N:
v = tl.load(ptrs_v + block_start_index_n * v_k_stride, mask=offs_n[:, None] < size_n, other=0.0)
else:
v = tl.load(ptrs_v + block_start_index_n * v_k_stride)
dp = tl.dot(d_o, v, trans_b=True)
d_i = tl.load(D + offs_m)
ds = p * (dp - d_i[:, None]) * sm_scale
dk += tl.dot(ds, q, trans_a=True)
if NEED_LOAD_MASK_SIZE_M:
d_q = tl.load(ptrs_d_q, mask=offs_m[:, None] < size_m, other=0.0, eviction_policy="evict_last")
else:
d_q = tl.load(ptrs_d_q, eviction_policy="evict_last")
d_q += tl.dot(ds, k)
tl.store(ptrs_d_q, d_q, eviction_policy="evict_last")
ptrs_d_q += BLOCK_M
ptrs_q += BLOCK_M
ptrs_d_o += BLOCK_M
ptrs_d_k = DK + offs_d_q
ptrs_d_v = DV + offs_d_q
if NEED_LOAD_MASK_SIZE_M | NEED_LOAD_MASK_SIZE_N:
tl.store(ptrs_d_k, dk, mask=offs_m[:, None] < size_m)
tl.store(ptrs_d_v, dv, mask=offs_m[:, None] < size_m)
else:
tl.store(ptrs_d_k, dk)
tl.store(ptrs_d_v, dv)
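# For a sanity check, here is a plain-PyTorch sketch (single head, illustrative names
# and sizes, not the kernel code itself) of the dense math a flash-attention backward
# has to reproduce block by block:
import torch

m_size, n_size, d_head, sm_scale = 16, 16, 64, 0.125
q, d_o = torch.randn(m_size, d_head), torch.randn(m_size, d_head)
k, v = torch.randn(n_size, d_head), torch.randn(n_size, d_head)

p = torch.softmax((q @ k.T) * sm_scale, dim=-1)    # forward probabilities, [m, n]
dv = p.T @ d_o                                     # [n, d_head]
dp = d_o @ v.T                                     # [m, n]
delta = (p * dp).sum(dim=-1, keepdim=True)         # same value _bwd_preprocess produces
ds = p * (dp - delta) * sm_scale                   # softmax + scale backward, [m, n]
dq = ds @ k                                        # [m, d_head]
dk = ds.T @ q                                      # [n, d_head]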
I greatly appreciate any help.
Thank you,
Enrico
Thank you a lot for this work. I will try to switch to it soon. In the meantime, have you seen this implementation?
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
I have not tested the backward pass myself, but the author seems to have fixed a bunch of Triton bugs (mostly with debug barriers).
Moreover, is there a specific aspect of the kernel you wrote where you think you may need help?
Hi @pommedeterresautee,
I tried to follow the same code structure as kernl's forward pass to remain consistent when working on my backward kernel implementation above.
I reviewed Tri's implementation that you linked, which is definitely more robust. I will have to test and benchmark it soon as well.
In my implementation, I was having issues with the _bwd_preprocess function and was unsure how to handle ptrs_d_q, ptrs_q, and ptrs_d_o in _bwd_kernel:
ptrs_d_q += BLOCK_M
ptrs_q += BLOCK_M
ptrs_d_o += BLOCK_M
In Tri's implementation, he multiplies them by different strides. I do not know whether the same approach holds here, since the initial offsets are computed differently.
I am also unsure whether I am handling the loading of the variable m correctly in the case where there is no mask. Should I add something like if not HAS_MASK: m = tl.load(??? + offs_m) before p = tl.exp(qk * sm_scale - m[:, None])?
I greatly appreciate your time and help.
Thank you,
Enrico
Hello, I'll try to answer; tell me if I misunderstood the question.
The ptrs_* variables hold memory addresses, and BLOCK_M is just an integer, for example 32. So ptrs_d_q += BLOCK_M shifts every pointer by 32 elements, which will probably land you somewhere you do not want to be:
- you may start the load in the middle of a row, and a single loaded row could even span multiple rows of the original tensor;
- at each iteration you will probably load almost the same data again and again.
The row stride, for example, is the distance in memory between two consecutive rows: starting from a given position and adding the row stride lands you at the same column on the next row.
So when you add BLOCK_M multiplied by some stride (for example the row stride of the matrix), you land BLOCK_M rows below.
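Concretely, using the strides already passed to the kernel, the updates would look something like this (illustrative only, to be adapted to your loop structure):
ptrs_q += BLOCK_M * q_m_stride      # move BLOCK_M rows down in Q, same columns
ptrs_d_o += BLOCK_M * d_o_m_stride  # same for dO
ptrs_d_q += BLOCK_M * d_q_m_stride  # same for dQ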
For the mask, I am not sure I understand: why do you want to load m if there is no mask? HAS_MASK indicates the presence or absence of the attention_mask tensor.
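On the host side that flag is typically derived from whether a mask tensor was passed at all, for example (illustrative wrapper-side logic, not the exact kernl code):
HAS_MASK = attention_mask is not None
IS_MATRIX_MASK = HAS_MASK and attention_mask.size(-2) > 1  # 2D (m, n) mask vs. broadcast vector mask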