
Backward Kernel Implementation

Open conceptofmind opened this issue 2 years ago • 3 comments

Hi,

I have been trying to make some progress on the backward kernel for training. Unfortunately, I am new to GPU programming and Triton, so I may be missing parts. Any advice or information you could provide on this implementation would be appreciated:

@triton.jit
def _bwd_preprocess(
    Out, DO, D,
    o_batch_stride, o_head_stride, o_m_stride,
    d_o_batch_stride, d_o_head_stride, d_o_m_stride,
    heads,
    BLOCK_M: tl.constexpr, BLOCK_DHEAD: tl.constexpr,
):
    m_block_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    current_batch_idx = head_idx // heads
    current_head_idx = head_idx % heads

    range_offs_m = tl.arange(0, BLOCK_M)
    off_m = m_block_idx * BLOCK_M + range_offs_m
    off_n = tl.arange(0, BLOCK_DHEAD)
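    # off_m indexes the sequence rows handled by this block; off_n indexes the head dimension (BLOCK_DHEAD)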

    # load one (BLOCK_M, BLOCK_DHEAD) tile of the output and of its gradient
    o = tl.load(Out + current_batch_idx * o_batch_stride + current_head_idx * o_head_stride
                + off_m[:, None] * o_m_stride + off_n[None, :]).to(tl.float32)
    d_o = tl.load(DO + current_batch_idx * d_o_batch_stride + current_head_idx * d_o_head_stride
                  + off_m[:, None] * d_o_m_stride + off_n[None, :]).to(tl.float32)
    # compute delta_i = sum_d O[i, d] * dO[i, d], reloaded as d_i in the main backward kernel
    delta = tl.sum(o * d_o, axis=1)
    # write-back
    tl.store(D + head_idx + off_m, delta)

@triton.jit
def _bwd_kernel(
    heads,
    size_m,
    size_n,
    size_m_cache_key,
    size_n_cache_key,
    Q, K, V,
    DO, DQ, DK, DV,
    D,
    sm_scale,
    attention_mask,
    q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
    k_batch_stride, k_head_stride, k_n_stride, k_k_stride, # axis named n,k instead of k,n because of the transpose of K matrix
    v_batch_stride, v_head_stride, v_k_stride, v_n_stride,
    d_o_batch_stride, d_o_head_stride, d_o_m_stride, d_o_n_stride,
    d_q_batch_stride, d_q_head_stride, d_q_m_stride, d_q_k_stride,
    attention_mask_batch_stride, attention_mask_head_stride, attention_mask_m_stride, attention_mask_n_stride,
    min_clamp_value,
    attention_mask_batch_size, attention_mask_head_size, attention_mask_m_size, attention_mask_n_size,
    HAS_MASK: tl.constexpr,
    IS_MATRIX_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_DHEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,  # this parameter and below are managed by the autotune and need to be at the end
    BLOCK_N: tl.constexpr,
    NEED_LOAD_MASK_SIZE_M: tl.constexpr,
    NEED_LOAD_MASK_SIZE_N: tl.constexpr,
):
    m_block_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    range_offs_m = tl.arange(0, BLOCK_M)
    range_offs_n = tl.arange(0, BLOCK_N)
    range_offs_d = tl.arange(0, BLOCK_DHEAD)

    offs_m = tl.program_id(0) * BLOCK_M + range_offs_m

    current_batch_idx = head_idx // heads
    current_head_idx = head_idx % heads

    offs_q = (
            current_batch_idx * q_batch_stride
            + current_head_idx * q_head_stride
            + (offs_m[:, None] * q_m_stride + range_offs_d[None, :] * q_k_stride)
    )

    offs_k = (
            current_batch_idx * k_batch_stride
            + current_head_idx * k_head_stride
            + (range_offs_n[:, None] * k_n_stride + range_offs_d[None, :] * k_k_stride)
    )

    offs_v = (
            current_batch_idx * v_batch_stride
            + current_head_idx * v_head_stride
            + (range_offs_n[:, None] * v_n_stride + range_offs_d[None, :] * v_k_stride)
    )

    offs_d_o = (
            current_batch_idx * d_o_batch_stride
            + current_head_idx * d_o_head_stride
            + (offs_m[:, None] * d_o_m_stride + range_offs_n[None, :] * d_o_n_stride)
    )

    offs_d_q = (
            current_batch_idx * d_q_batch_stride
            + current_head_idx * d_q_head_stride
            + (offs_m[:, None] * d_q_m_stride + range_offs_d[None, :] * d_q_k_stride)
    )

    ptrs_q = Q + offs_q
    ptrs_k = K + offs_k
    ptrs_v = V + offs_v
    ptrs_d_o = DO + offs_d_o
    ptrs_d_q = DQ + offs_d_q

    dv = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    dk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
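    # dv and dk accumulate the dV and dK contributions across the n-blocks of the loop below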

    if NEED_LOAD_MASK_SIZE_M | NEED_LOAD_MASK_SIZE_N:
        q = tl.load(ptrs_q, mask=offs_m[:, None] < size_m, other=0.0)
    else:
        q = tl.load(ptrs_q)

    n_end = size_n
    if IS_CAUSAL:
        n_end = ((m_block_idx + 1) * BLOCK_M,)

    if HAS_MASK:
        mask_batch_idx = (current_batch_idx,)
        if attention_mask_batch_size == 1:
            mask_batch_idx = 0

        mask_head_idx = current_head_idx
        if attention_mask_head_size == 1:
            mask_head_idx = 0

        offs_base_mask = mask_batch_idx * attention_mask_batch_stride + mask_head_idx * attention_mask_head_stride

    for block_start_index_n in range(0, size_n, BLOCK_N):

        block_start_index_n = tl.multiple_of(block_start_index_n, BLOCK_N)
        offs_n = block_start_index_n + range_offs_n

        if NEED_LOAD_MASK_SIZE_M:
            k = tl.load(ptrs_k, mask=offs_n[:, None] < size_n, other=0.0)
        else:
            k = tl.load(ptrs_k)

        qk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        if NEED_LOAD_MASK_SIZE_N:
            qk = tl.where(range_offs_n[:, None] < size_n, qk, float("-inf"))

        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))

        if HAS_MASK:
            # we assume mask has a vector shape
            offs_mask = offs_base_mask + offs_n[None, :] * attention_mask_n_stride
            if IS_MATRIX_MASK:  # mask has a matrix shape, we load (BLOCK_M, BLOCK_N) elements
                offs_mask += offs_m[:, None] * attention_mask_m_stride

            if NEED_LOAD_MASK_SIZE_N & (not IS_MATRIX_MASK):  # mask has a vector shape + need a load mask
                attention_load_mask = offs_n[None, :] < attention_mask_n_size
            if IS_MATRIX_MASK:  # mask has a matrix shape
                if NEED_LOAD_MASK_SIZE_M & (not NEED_LOAD_MASK_SIZE_N):  # load mask on M axis
                    attention_load_mask = offs_m[:, None] < attention_mask_m_size
                elif (not NEED_LOAD_MASK_SIZE_M) & NEED_LOAD_MASK_SIZE_N:  # load mask on N axis
                    attention_load_mask = offs_n[None, :] < attention_mask_n_size
                elif NEED_LOAD_MASK_SIZE_M & NEED_LOAD_MASK_SIZE_N:  # load mask on both axis
                    attention_load_mask = (offs_n[None, :] < attention_mask_n_size) & (
                        offs_m[:, None] < attention_mask_m_size
                    )

            if (NEED_LOAD_MASK_SIZE_M & IS_MATRIX_MASK) | NEED_LOAD_MASK_SIZE_N:
                m = tl.load(
                    attention_mask + offs_mask,
                    eviction_policy="evict_first",
                    mask=attention_load_mask,
                    other=float("-inf"),
                )
            else:
                m = tl.load(
                    attention_mask + offs_mask,
                    eviction_policy="evict_first",
                )
            # Avoids NaN
            m = tl.where(m == float("-inf"), min_clamp_value, m)
            qk += m

        p = tl.exp(qk * sm_scale - m[:, None])

        if NEED_LOAD_MASK_SIZE_M:
            d_o = tl.load(ptrs_d_o, mask=offs_m[:, None] < size_m, other=0.0)
        else:
            d_o = tl.load(ptrs_d_o)

        dv += tl.dot(p.to(d_o.dtype), d_o, trans_a=True)

        if NEED_LOAD_MASK_SIZE_N:
            v = tl.load(ptrs_v + block_start_index_n * v_k_stride, mask=offs_n[:, None] < size_n, other=0.0)
        else:
            v = tl.load(ptrs_v + block_start_index_n * v_k_stride)

        dp = tl.dot(d_o, v, trans_b=True)

        d_i = tl.load(D + offs_m)
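        # d_i is delta from _bwd_preprocess; the next line is the softmax gradient dS = P * (dP - delta)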

        ds = p * (dp - d_i[:, None]) * sm_scale

        dk += tl.dot(ds, q, trans_a=True)

        if NEED_LOAD_MASK_SIZE_M:
            d_q = tl.load(ptrs_d_q, mask=offs_m[:, None] < size_m, other=0.0, eviction_policy="evict_last")
        else:
            d_q = tl.load(ptrs_d_q, eviction_policy="evict_last")

        d_q += tl.dot(ds, k)

        tl.store(ptrs_d_q, d_q, eviction_policy="evict_last")

        ptrs_d_q += BLOCK_M
        ptrs_q += BLOCK_M
        ptrs_d_o += BLOCK_M

    ptrs_d_k = DK + offs_d_q
    ptrs_d_v = DV + offs_d_q

    if NEED_LOAD_MASK_SIZE_M | NEED_LOAD_MASK_SIZE_N:
        tl.store(ptrs_d_k, dk, mask=offs_m[:, None] < size_m)
        tl.store(ptrs_d_v, dv, mask=offs_m[:, None] < size_m)
    else:
        tl.store(ptrs_d_k, dk)
        tl.store(ptrs_d_v, dv)

I greatly appreciate any help.

Thank you,

Enrico

conceptofmind · Dec 06 '22 18:12

Thank you a lot for this work. I will try to switch to it soon; in the meantime, have you seen this implementation?

https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py

I have not tested the backward pass myself, but the author seems to have fixed a bunch of Triton bugs (mostly with debug barriers).

Moreover, is there an aspect of the kernel you wrote where you think you may need help?

pommedeterresautee · Dec 07 '22 05:12

Hi @pommedeterresautee ,

I tried to follow the same code structure as kernl's forward pass to remain consistent when working on my backward kernel implementation above.

I reviewed Tri's implementation you linked, which is definitely more robust. I will have to test and benchmark it soon as well.

In my implementation, I was having difficulty with the _bwd_preprocess function and was unsure how to handle ptrs_d_q, ptrs_q, and ptrs_d_o in the _bwd_kernel:

        ptrs_d_q += BLOCK_M
        ptrs_q += BLOCK_M
        ptrs_d_o += BLOCK_M

In Tri's implementation, the pointer increments are multiplied by the corresponding strides. I do not know whether the same approach holds here, since the initialized offsets are different.

I am also unsure whether I am handling the loading of the variable m correctly in the case where there is no mask. Should I add something like if not HAS_MASK: m = tl.load(??? + offs_m) before p = tl.exp(qk * sm_scale - m[:, None])?

I greatly appreciate your time and help.

Thank you,

Enrico

conceptofmind · Dec 08 '22 03:12

Hello, I'll try to answer; tell me if I misunderstood the question.

The variables named ptrs_* hold memory addresses, and BLOCK_M is just an integer, for example 32. So when you write ptrs_d_q += BLOCK_M, it shifts every pointer by 32 elements, and you will probably land somewhere you do not want to:

  • you may start a load in the middle of a row, and the loaded rows could even span multiple rows of the original tensor
  • at each iteration you will probably load almost the same data again and again

The row stride (for example) is the distance in memory between two consecutive rows. So when you are at a position and you add the row stride, you land in the same column but on the next row.

So when you multiply BLOCK_M by some stride (for example the matrix row stride), you will land BLOCK_M rows below.
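
For example, to move BLOCK_M rows down, each update would be multiplied by the corresponding m stride (an untested sketch using the variable names of your kernel):

        ptrs_d_q += BLOCK_M * d_q_m_stride  # move BLOCK_M rows down in DQ
        ptrs_q += BLOCK_M * q_m_stride      # move BLOCK_M rows down in Q
        ptrs_d_o += BLOCK_M * d_o_m_stride  # move BLOCK_M rows down in DO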

For the mask, I'm not sure I understand: why do you want to load m if there is no mask? HAS_MASK indicates the presence or absence of the attention_mask tensor.

gaetansnl avatar Dec 14 '22 09:12 gaetansnl