
[BUG] Mismatched behavior between for-loop and SIMD versions

Open · PannenetsF opened this issue on Dec 23, 2023 · 0 comments

Hi everyone, I ran into a problem where a Triton kernel produces different results depending on whether it iterates with a loop or vectorizes with tl.arange.

A minimal reproduction follows. Here q has shape (bs, head_num, head_dim), k has shape (bs, seq_len, head_num, head_dim), and the score tensor has shape (bs, head_num, ceil(seq_len / chunk_size), chunk_size). The loop version, qk_mul_kernel_loop, matches the PyTorch reference, but the vectorized version, qk_mul_kernel, which builds 2D offsets with tl.arange and tl.view, fails the allclose check.

I would appreciate it if anyone could help. Thanks.

import torch

import triton
import triton.language as tl
import math

@triton.jit
def qk_mul_kernel(
    q_tensor, 
    q_bs, q_H, q_h,
    k_tensor,
    k_bs, k_L, k_H, k_h,
    score_tensor,
    st_bs, st_H, st_N, st_C,
    BATCH_SIZE: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MAX_LENGTH: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    chunk_idx = tl.program_id(2)

    is_active_block = (chunk_idx < MAX_LENGTH) and (batch_idx < BATCH_SIZE)  # guard: valid chunk and batch indices
    is_active_head = head_idx < HEAD_NUM  # guard: valid head index
    if not is_active_block or not is_active_head:
        return
    
    # per-lane offsets along the chunk and head dimensions
    head_off = tl.arange(0, HEAD_DIM)
    block_off = tl.arange(0, CHUNK_SIZE)

    # q vector
    q_off = batch_idx * q_bs + head_idx * q_H + head_off * q_h
    q_vec = tl.load(q_tensor + q_off)
    q_vec = tl.view(q_vec, (1, HEAD_DIM))

    # k matrix
    leng_off = block_off + chunk_idx * CHUNK_SIZE
    leng_off_2d = tl.view(leng_off, (CHUNK_SIZE, 1))
    head_off_2d = tl.view(head_off, (1, HEAD_DIM))
    leng_mask = leng_off < MAX_LENGTH
    leng_mask_2d = tl.view(leng_mask, (CHUNK_SIZE, 1))
    k_off = batch_idx * k_bs + leng_off_2d * k_L + head_idx * k_H + head_off_2d * k_h
    k_mat = tl.load(k_tensor + k_off, mask=leng_mask_2d, other=0.)  # (CHUNK_SIZE, HEAD_DIM)
    
    # score
    score_off = batch_idx * st_bs + head_idx * st_H + chunk_idx * st_N + block_off * st_C
    score_vec = tl.sum(q_vec * k_mat, axis=1)
    tl.store(score_tensor + score_off, score_vec)
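
One detail that may be relevant: the Triton documentation notes that tl.view may not preserve the ordering of elements. If that caveat applies here, the 2D offsets could instead be built with broadcasting via None-indexing, which does guarantee order. A sketch of that variant of the offset construction in qk_mul_kernel (an assumption on my part, not verified):

    # hypothetical order-preserving variant: broadcasting instead of tl.view
    head_off = tl.arange(0, HEAD_DIM)
    block_off = tl.arange(0, CHUNK_SIZE)
    leng_off = block_off + chunk_idx * CHUNK_SIZE
    leng_off_2d = leng_off[:, None]                  # (CHUNK_SIZE, 1)
    head_off_2d = head_off[None, :]                  # (1, HEAD_DIM)
    leng_mask_2d = (leng_off < MAX_LENGTH)[:, None]  # (CHUNK_SIZE, 1)
    q_row = q_vec[None, :]                           # (1, HEAD_DIM), replacing tl.view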


@triton.jit
def qk_mul_kernel_loop(
    q_tensor, 
    q_bs, q_H, q_h,
    k_tensor,
    k_bs, k_L, k_H, k_h,
    score_tensor,
    st_bs, st_H, st_N, st_C,
    BATCH_SIZE: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MAX_LENGTH: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    chunk_idx = tl.program_id(2)

    is_active_block = (chunk_idx < MAX_LENGTH) and (batch_idx < BATCH_SIZE)  # guard: valid chunk and batch indices
    is_active_head = head_idx < HEAD_NUM  # guard: valid head index
    if not is_active_block or not is_active_head:
        return
    
    # per-lane offsets along the head dimension
    head_off = tl.arange(0, HEAD_DIM)

    # q vector
    q_off = batch_idx * q_bs + head_idx * q_H + head_off * q_h
    q_vec = tl.load(q_tensor + q_off)

    for block_off in tl.static_range(CHUNK_SIZE):
        # k matrix
        leng_off = block_off + chunk_idx * CHUNK_SIZE
        leng_mask = leng_off < MAX_LENGTH
        k_off = batch_idx * k_bs + leng_off * k_L + head_idx * k_H + head_off * k_h
        k_mat = tl.load(k_tensor + k_off, mask=leng_mask, other=0.)  # (HEAD_DIM,)
        
        # score
        score_off = batch_idx * st_bs + head_idx * st_H + chunk_idx * st_N + block_off * st_C
        score_vec = tl.sum(q_vec * k_mat)
        tl.store(score_tensor + score_off, score_vec)

@torch.no_grad()
def qk_mul(
    q_tensor,
    k_tensor,
    score_tensor,
    method='vec'
):

    # score (bs, head_num, ceil(seq_len / chunk_size), chunk_size)
    BATCH_SIZE = q_tensor.shape[0]
    HEAD_NUM = q_tensor.shape[1]
    HEAD_DIM = q_tensor.shape[2]
    CHUNK_SIZE = 64
    MAX_LENGTH = k_tensor.shape[1]
    grid = lambda META: (BATCH_SIZE, HEAD_NUM, math.ceil(MAX_LENGTH / CHUNK_SIZE))
    q_bs, q_H, q_h = q_tensor.stride(0), q_tensor.stride(1), q_tensor.stride(2)
    k_bs, k_L, k_H, k_h = k_tensor.stride(0), k_tensor.stride(1), k_tensor.stride(2), k_tensor.stride(3)
    st_bs, st_H, st_N, st_C = score_tensor.stride(0), score_tensor.stride(1), score_tensor.stride(2), score_tensor.stride(3)
    if method == 'loop':
        qk_mul_kernel_loop[grid](q_tensor, 
                            q_bs, q_H, q_h,
                            k_tensor,
                            k_bs, k_L, k_H, k_h,
                            score_tensor,
                            st_bs, st_H, st_N, st_C,
                            BATCH_SIZE, HEAD_NUM, HEAD_DIM, CHUNK_SIZE, MAX_LENGTH)
    else:
        qk_mul_kernel[grid](q_tensor, 
                            q_bs, q_H, q_h,
                            k_tensor,
                            k_bs, k_L, k_H, k_h,
                            score_tensor,
                            st_bs, st_H, st_N, st_C,
                            BATCH_SIZE, HEAD_NUM, HEAD_DIM, CHUNK_SIZE, MAX_LENGTH)

@torch.no_grad()
def torch_qk_mul(
    q_tensor, 
    k_tensor,
    triton_score_tensor,
):
    # q (bs, head_num, head_dim)
    # k (bs, seq_len, head_num, head_dim)
    # k_t (bs, head_num, seq_len, head_dim)
    # q_t (bs, head_num, head_dim, 1)
    # score (bs, head_num, seq_len, 1)
    bs, seq_len, head_num, head_dim = k_tensor.shape
    k_t = k_tensor.detach().clone().permute(0, 2, 1, 3).contiguous()
    q_t = q_tensor.detach().clone().reshape(bs, head_num, head_dim, 1).contiguous()
    score_tensor = torch.matmul(k_t, q_t)
    pad_len = math.ceil(seq_len / 64) * 64
    pad_size = pad_len - seq_len
    if pad_size > 0:
        score_tensor = torch.cat([score_tensor, torch.zeros(bs, head_num, pad_size, 1, device=score_tensor.device)], dim=2)
    score_tensor = score_tensor.reshape(bs, head_num, -1)
    
    # triton_score_tensor (bs, head_num, ceil(seq_len / chunk_size), chunk_size)
    triton_score_tensor = triton_score_tensor.detach().clone().reshape(bs, head_num, -1)
    
    st = score_tensor.reshape(-1, pad_len)
    tst = triton_score_tensor.reshape(-1, pad_len)
    # cos = torch.nn.functional.cosine_similarity(st, tst, dim=0)
    assert torch.allclose(st, tst, atol=1e-5), "triton and torch scores do not match"
    

if __name__ == "__main__":
    bs = 4
    head_num = 8
    seq_len = 60
    head_dim = 16
    q_tensor = torch.randn(bs, head_num, head_dim, device='cuda:0')
    k_tensor = torch.randn(bs, seq_len, head_num, head_dim, device='cuda:0')
    score_tensor = torch.zeros(bs, head_num, math.ceil(seq_len / 64), 64, device='cuda:0')
    qk_mul(q_tensor, k_tensor, score_tensor, 'loop')
    torch_qk_mul(q_tensor, k_tensor, score_tensor)
    print('the loop passed')
    qk_mul(q_tensor, k_tensor, score_tensor)
    torch_qk_mul(q_tensor, k_tensor, score_tensor)
    print('the vec passed')
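
For anyone reproducing this, a quick way to quantify the mismatch is to run both kernels and compare their outputs directly, rather than against PyTorch (a sketch, assuming the same setup as above):

    # run both kernels into separate buffers and report the largest
    # absolute difference between the loop and the vectorized path
    score_loop = torch.zeros_like(score_tensor)
    score_vec = torch.zeros_like(score_tensor)
    qk_mul(q_tensor, k_tensor, score_loop, 'loop')
    qk_mul(q_tensor, k_tensor, score_vec, 'vec')
    print('max abs diff (loop vs vec):', (score_loop - score_vec).abs().max().item())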
