[BUG] Mismatched behavior between for-loop and SIMD versions of a kernel
Hi everyone, I ran into a problem where the result depends on whether I use a for-loop or vectorized indexing.
The minimal repro is below. The loop version (qk_mul_kernel_loop) matches the PyTorch reference, but the vectorized version (qk_mul_kernel, which builds 2-D offsets from tl.arange with tl.view) fails the torch.allclose check in torch_qk_mul.
I would appreciate it if anyone could help. Thanks.
import torch
import triton
import triton.language as tl
import math

@triton.jit
def qk_mul_kernel(
    q_tensor,
    q_bs, q_H, q_h,
    k_tensor,
    k_bs, k_L, k_H, k_h,
    score_tensor,
    st_bs, st_H, st_N, st_C,
    BATCH_SIZE: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MAX_LENGTH: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    chunk_idx = tl.program_id(2)
    is_active_block = (chunk_idx < MAX_LENGTH) and (batch_idx < BATCH_SIZE)  # the block is in range
    is_active_head = head_idx < HEAD_NUM  # the head is in range
    if not is_active_block or not is_active_head:
        return
    # offsets within the head dimension and within the chunk
    head_off = tl.arange(0, HEAD_DIM)
    block_off = tl.arange(0, CHUNK_SIZE)
    # q vector
    q_off = batch_idx * q_bs + head_idx * q_H + head_off * q_h
    q_vec = tl.load(q_tensor + q_off)
    q_vec = tl.view(q_vec, (1, HEAD_DIM))
    # k matrix
    leng_off = block_off + chunk_idx * CHUNK_SIZE
    leng_off_2d = tl.view(leng_off, (CHUNK_SIZE, 1))
    head_off_2d = tl.view(head_off, (1, HEAD_DIM))
    leng_mask = leng_off < MAX_LENGTH
    leng_mask_2d = tl.view(leng_mask, (CHUNK_SIZE, 1))
    k_off = batch_idx * k_bs + leng_off_2d * k_L + head_idx * k_H + head_off_2d * k_h
    k_mat = tl.load(k_tensor + k_off, mask=leng_mask_2d, other=0.)  # (CHUNK_SIZE, HEAD_DIM)
    # score
    score_off = batch_idx * st_bs + head_idx * st_H + chunk_idx * st_N + block_off * st_C
    score_vec = tl.sum(q_vec * k_mat, axis=1)
    tl.store(score_tensor + score_off, score_vec)

@triton.jit
def qk_mul_kernel_loop(
    q_tensor,
    q_bs, q_H, q_h,
    k_tensor,
    k_bs, k_L, k_H, k_h,
    score_tensor,
    st_bs, st_H, st_N, st_C,
    BATCH_SIZE: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MAX_LENGTH: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    chunk_idx = tl.program_id(2)
    is_active_block = (chunk_idx < MAX_LENGTH) and (batch_idx < BATCH_SIZE)  # the block is in range
    is_active_head = head_idx < HEAD_NUM  # the head is in range
    if not is_active_block or not is_active_head:
        return
    # offset within the head dimension
    head_off = tl.arange(0, HEAD_DIM)
    # q vector
    q_off = batch_idx * q_bs + head_idx * q_H + head_off * q_h
    q_vec = tl.load(q_tensor + q_off)
    for block_off in tl.static_range(CHUNK_SIZE):
        # one row of k
        leng_off = block_off + chunk_idx * CHUNK_SIZE
        leng_mask = leng_off < MAX_LENGTH
        k_off = batch_idx * k_bs + leng_off * k_L + head_idx * k_H + head_off * k_h
        k_mat = tl.load(k_tensor + k_off, mask=leng_mask, other=0.)  # (HEAD_DIM,)
        # score
        score_off = batch_idx * st_bs + head_idx * st_H + chunk_idx * st_N + block_off * st_C
        score_vec = tl.sum(q_vec * k_mat)
        tl.store(score_tensor + score_off, score_vec)

@torch.no_grad()
def qk_mul(
    q_tensor,
    k_tensor,
    score_tensor,
    method='vec'
):
    # score (bs, head_num, seq_len / chunk_size, chunk_size)
    BATCH_SIZE = q_tensor.shape[0]
    HEAD_NUM = q_tensor.shape[1]
    HEAD_DIM = q_tensor.shape[2]
    CHUNK_SIZE = 64
    MAX_LENGTH = k_tensor.shape[1]
    grid = lambda META: (BATCH_SIZE, HEAD_NUM, math.ceil(MAX_LENGTH / CHUNK_SIZE))
    q_bs, q_H, q_h = q_tensor.stride(0), q_tensor.stride(1), q_tensor.stride(2)
    k_bs, k_L, k_H, k_h = k_tensor.stride(0), k_tensor.stride(1), k_tensor.stride(2), k_tensor.stride(3)
    st_bs, st_H, st_N, st_C = score_tensor.stride(0), score_tensor.stride(1), score_tensor.stride(2), score_tensor.stride(3)
    if method == 'loop':
        qk_mul_kernel_loop[grid](q_tensor,
                                 q_bs, q_H, q_h,
                                 k_tensor,
                                 k_bs, k_L, k_H, k_h,
                                 score_tensor,
                                 st_bs, st_H, st_N, st_C,
                                 BATCH_SIZE, HEAD_NUM, HEAD_DIM, CHUNK_SIZE, MAX_LENGTH)
    else:
        qk_mul_kernel[grid](q_tensor,
                            q_bs, q_H, q_h,
                            k_tensor,
                            k_bs, k_L, k_H, k_h,
                            score_tensor,
                            st_bs, st_H, st_N, st_C,
                            BATCH_SIZE, HEAD_NUM, HEAD_DIM, CHUNK_SIZE, MAX_LENGTH)

@torch.no_grad()
def torch_qk_mul(
    q_tensor,
    k_tensor,
    triton_score_tensor,
):
    # q (bs, head_num, head_dim)
    # k (bs, seq_len, head_num, head_dim)
    # k_t (bs, head_num, seq_len, head_dim)
    # q_t (bs, head_num, head_dim, 1)
    # score (bs, head_num, seq_len, 1)
    bs, seq_len, head_num, head_dim = k_tensor.shape
    k_t = k_tensor.detach().clone().permute(0, 2, 1, 3).contiguous()
    q_t = q_tensor.detach().clone().reshape(bs, head_num, head_dim, 1).contiguous()
    score_tensor = torch.matmul(k_t, q_t)
    pad_len = math.ceil(seq_len / 64) * 64
    pad_size = pad_len - seq_len
    if pad_size > 0:
        score_tensor = torch.cat([score_tensor, torch.zeros(bs, head_num, pad_size, 1, device=score_tensor.device)], dim=2)
    score_tensor = score_tensor.reshape(bs, head_num, -1)
    # triton_score_tensor (bs, head_num, seq_len / chunk_size, chunk_size)
    triton_score_tensor = triton_score_tensor.detach().clone().reshape(bs, head_num, -1)
    st = score_tensor.reshape(-1, pad_len)
    tst = triton_score_tensor.reshape(-1, pad_len)
    # cos = torch.nn.functional.cosine_similarity(st, tst, dim=0)
    assert torch.allclose(st, tst, atol=1e-5), "not equal"

if __name__ == "__main__":
    bs = 4
    head_num = 8
    seq_len = 60
    head_dim = 16
    q_tensor = torch.randn(bs, head_num, head_dim, device='cuda:0')
    k_tensor = torch.randn(bs, seq_len, head_num, head_dim, device='cuda:0')
    score_tensor = torch.zeros(bs, head_num, math.ceil(seq_len / 64), 64, device='cuda:0')
    qk_mul(q_tensor, k_tensor, score_tensor, 'loop')
    torch_qk_mul(q_tensor, k_tensor, score_tensor)
    print('the loop passed')
    qk_mul(q_tensor, k_tensor, score_tensor)
    torch_qk_mul(q_tensor, k_tensor, score_tensor)
    print('the vec passed')
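
For reference, here is a sketch of the vectorized kernel with the 2-D pointer offsets built by broadcasting indexing (leng_off[:, None], head_off[None, :]) instead of tl.view. My possibly wrong understanding is that tl.view does not guarantee the element order after reshaping, which is the part I suspect, but I have not confirmed that this variant is the intended way to write it; the name qk_mul_kernel_broadcast is just mine for this sketch, and it would be launched with the same grid and arguments as qk_mul_kernel above.

# Sketch only: same computation as qk_mul_kernel, but the 2-D offsets are built
# with [:, None] / [None, :] broadcasting rather than tl.view (assumption: this
# preserves the logical element order).
@triton.jit
def qk_mul_kernel_broadcast(
    q_tensor,
    q_bs, q_H, q_h,
    k_tensor,
    k_bs, k_L, k_H, k_h,
    score_tensor,
    st_bs, st_H, st_N, st_C,
    BATCH_SIZE: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MAX_LENGTH: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    chunk_idx = tl.program_id(2)
    is_active_block = (chunk_idx < MAX_LENGTH) and (batch_idx < BATCH_SIZE)  # the block is in range
    is_active_head = head_idx < HEAD_NUM  # the head is in range
    if not is_active_block or not is_active_head:
        return
    # offsets within the head dimension and within the chunk
    head_off = tl.arange(0, HEAD_DIM)
    block_off = tl.arange(0, CHUNK_SIZE)
    # q vector, loaded as (HEAD_DIM,) and broadcast against k below
    q_off = batch_idx * q_bs + head_idx * q_H + head_off * q_h
    q_vec = tl.load(q_tensor + q_off)
    # k matrix: (CHUNK_SIZE, HEAD_DIM) offsets via broadcasting, no tl.view
    leng_off = block_off + chunk_idx * CHUNK_SIZE
    leng_mask = leng_off < MAX_LENGTH
    k_off = (batch_idx * k_bs
             + leng_off[:, None] * k_L
             + head_idx * k_H
             + head_off[None, :] * k_h)
    k_mat = tl.load(k_tensor + k_off, mask=leng_mask[:, None], other=0.)  # (CHUNK_SIZE, HEAD_DIM)
    # score: dot product of q with every row of k
    score_off = batch_idx * st_bs + head_idx * st_H + chunk_idx * st_N + block_off * st_C
    score_vec = tl.sum(q_vec[None, :] * k_mat, axis=1)
    tl.store(score_tensor + score_off, score_vec)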