BUG in flash attention kernel
import math
import pytest
import torch
import triton
import triton.language as tl
# from triton.runtime.interpreter import TensorHandle
@triton.jit
def _fwd_kernel2(
Q,
K,
V,
# POS_EMB2,
sm_scale,
Out,
# stride_e0,
# stride_e1,
stride_qh,
stride_qm,
stride_qk,
stride_kh,
stride_kn,
stride_kk,
stride_vh,
stride_vk,
stride_vn,
stride_oh,
stride_om,
stride_on,
N_CTX,
N_CTX_2,
HIDDEN_DIM,
# emb_len,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HIDDEN_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HIDDEN_DIM, N_CTX_2),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX_2, HIDDEN_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
# l_i = tl.where(offs_m < N_CTX, l_i, 1)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# credits to: Adam P. Goucher (https://github.com/apgoucher):
# scale sm_scale by 1/log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr, padding_option="zero", boundary_check=(0, 1))
q = (q * qk_scale).to(K.dtype.element_ty)
lo = 0
hi = N_CTX_2
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr, padding_option="zero", boundary_check=(0, 1))
v = tl.load(V_block_ptr, padding_option="zero", boundary_check=(0, 1))
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# if IS_CAUSAL:
# qk = tl.where(
# offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")
# )
qk += tl.dot(q, k, allow_tf32=True)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HIDDEN_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
tl.store(O_block_ptr, acc.to(K.dtype.element_ty), boundary_check=(0, 1))
def forward(q, k, v, sm_scale, sequence_parallel=False):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError(
"Flash attention currently only supported for compute capability >= 80"
)
BLOCK_M = 128
BLOCK_N = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
BLOCK_HEADDIM = max(triton.next_power_of_2(Lk), 16)
# assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[0], BLOCK_M), q.shape[1], 1)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel2[grid](
q,
k,
v,
sm_scale,
o,
q.stride(1),
q.stride(0),
q.stride(2),
k.stride(1),
k.stride(0),
k.stride(2),
v.stride(1),
v.stride(0),
v.stride(2),
o.stride(1),
o.stride(0),
o.stride(2),
q.shape[0],
k.shape[0],
Lk,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_HEADDIM,
num_warps=num_warps,
num_stages=4,
)
return o
# @pytest.mark.parametrize("batch, seqlen_q, nheads, d,", [(1, 2, 1024, 64)])
# @pytest.mark.parametrize("causal", [True])
@torch.no_grad()
def test_op(seqlen, nheads, d, dtype=torch.float16, q_ctx=None):
if q_ctx is None:
q_ctx = seqlen
device = "cuda"
assert d <= 128, "FlashAttention only supports head dimensions up to 128"
torch.manual_seed(20)
q = torch.empty((q_ctx, nheads, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
k = torch.empty((seqlen, nheads, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
v = torch.empty((seqlen, nheads, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
sm_scale = 0.5
tri_out = forward(
q.to(device),
k.to(device),
v.to(device),
# pos_emb.to(device),
sm_scale=sm_scale,
).to(dtype)
# reference implementation
dots = torch.matmul(q.transpose(0,1), k.transpose(0,1).transpose(-1, -2)) * sm_scale
attn = torch.softmax(
dots.float(), axis=-1
).half()
ref_out = torch.matmul(attn, v.transpose(0,1).half()).detach().to(dtype).to(device).transpose(0,1)
# triton implementation
# compare
# assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
print("max diff: ", (ref_out - tri_out).abs().max().item())
if __name__ == "__main__":
test_op(4096, 8, 32, torch.float16)
test_op(4096, 8, 32, torch.float16, q_ctx=100)
test_op(4096, 800, 32, torch.float16, q_ctx=10)
max diff: 3.814697265625e-05
max diff: 3.0517578125e-05
max diff: 3.0517578125e-05
This version gives correct output, but when I change the QKV layout from (seqlen, nheads, d) to (nheads, seqlen, d), I get wrong results.
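The kernel body in the second version below is the same as above; the changes are in forward() (grid, strides, and shape arguments) and in test_op. As a minimal illustration of what the layout change means (a sketch added for clarity, not part of the original post), a tensor in the new layout is the old one with its first two axes permuted:

import torch

# Illustration only: (seqlen, nheads, d) -> (nheads, seqlen, d) is a permute of
# the first two axes; the head axis moves to the front, so the per-head stride
# is no longer simply d but seqlen * d.
q_old = torch.randn(4096, 8, 32, dtype=torch.float16)   # (seqlen, nheads, d)
q_new = q_old.permute(1, 0, 2).contiguous()              # (nheads, seqlen, d)
assert q_new.shape == (8, 4096, 32)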
import math
import pytest
import torch
import triton
import triton.language as tl
# from triton.runtime.interpreter import TensorHandle
@triton.jit
def _fwd_kernel2(
Q,
K,
V,
# POS_EMB2,
sm_scale,
Out,
# stride_e0,
# stride_e1,
stride_qh,
stride_qm,
stride_qk,
stride_kh,
stride_kn,
stride_kk,
stride_vh,
stride_vk,
stride_vn,
stride_oh,
stride_om,
stride_on,
N_CTX,
N_CTX_2,
HIDDEN_DIM,
# emb_len,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HIDDEN_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HIDDEN_DIM, N_CTX_2),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX_2, HIDDEN_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
# l_i = tl.where(offs_m < N_CTX, l_i, 1)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# credits to: Adam P. Goucher (https://github.com/apgoucher):
# scale sm_scale by 1/log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr, padding_option="zero", boundary_check=(0, 1))
q = (q * qk_scale).to(K.dtype.element_ty)
lo = 0
hi = N_CTX_2
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr, padding_option="zero", boundary_check=(0, 1))
v = tl.load(V_block_ptr, padding_option="zero", boundary_check=(0, 1))
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# if IS_CAUSAL:
# qk = tl.where(
# offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")
# )
qk += tl.dot(q, k, allow_tf32=True)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HIDDEN_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
tl.store(O_block_ptr, acc.to(K.dtype.element_ty), boundary_check=(0, 1))
def forward(q, k, v, sm_scale, sequence_parallel=False):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError(
"Flash attention currently only supported for compute capability >= 80"
)
BLOCK_M = 128
BLOCK_N = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
BLOCK_HEADDIM = max(triton.next_power_of_2(Lk), 16)
# assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel2[grid](
q,
k,
v,
sm_scale,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
q.shape[1],
k.shape[1],
Lk,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_HEADDIM,
num_warps=num_warps,
num_stages=4,
)
return o
# @pytest.mark.parametrize("batch, seqlen_q, nheads, d,", [(1, 2, 1024, 64)])
# @pytest.mark.parametrize("causal", [True])
@torch.no_grad()
def test_op(nheads, seqlen_q, d, dtype=torch.float16, q_ctx=None):
if q_ctx is None:
q_ctx = seqlen_q
device = "cuda"
assert d <= 128, "FlashAttention only supports head dimensions up to 128"
torch.manual_seed(20)
q = torch.empty((nheads, q_ctx, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
k = torch.empty((nheads, seqlen_q, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
v = torch.empty((nheads, seqlen_q, d), dtype=dtype, device="cuda").normal_(
mean=0.0, std=0.5
)
sm_scale = 0.5
tri_out = forward(
q.to(device),
k.to(device),
v.to(device),
# pos_emb.to(device),
sm_scale=sm_scale,
).to(dtype)
# reference implementation
dots = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
attn = torch.softmax(
dots.float(), axis=-1
).half()
ref_out = torch.matmul(attn, v.half()).detach().to(dtype).to(device)
# triton implementation
# compare
# assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
print("max diff: ", (ref_out - tri_out).abs().max().item())
if __name__ == "__main__":
test_op(8, 4096, 32, torch.float16)
test_op(8, 4096, 32, torch.float16, q_ctx=100)
test_op(800, 4096, 32, torch.float16, q_ctx=10)
test_op(8, 4096, 32, torch.float16, q_ctx=4095)
max diff: 4.57763671875e-05
max diff: 0.060211181640625
max diff: 0.07171630859375
max diff: 0.0134429931640625
It is also very strange that the error differs so much between q_ctx = 4095 and q_ctx = 4096 (0.0134429931640625 vs. 4.57763671875e-05).
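One way to narrow this down is to check whether the mismatch is confined to particular heads or query rows; a minimal debugging sketch (added as a suggestion, assuming ref_out and tri_out from test_op above, both of shape (nheads, q_ctx, d)):

# Debugging sketch (not part of the original test): localize the error
# per attention head and find the worst query position.
diff = (ref_out - tri_out).abs()
print("max diff per head:", diff.amax(dim=(1, 2)).tolist())
print("worst query row:  ", diff.amax(dim=(0, 2)).argmax().item())

If only head 0 matches, one candidate worth checking (an observation, not a confirmed diagnosis) is the per-head base offset: the kernel adds the same qvk_offset = off_hz * stride_qh to the Q, K, V, and Out base pointers, but in the (nheads, seqlen, d) layout q.stride(0) = q_ctx * d differs from k.stride(0) = seqlen * d whenever q_ctx != seqlen.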
Hi @zhanglei1172, the standard input format is (batches, nheads, seqlen, d).
Note that N_CTX_2 (batches x hiddens x sequence_length) was introduced in the PR to support Hopper TMA. I think the kernel you used is mixing the two:
- you loaded K_block_ptr with shape (HIDDEN_DIM, N_CTX_2) in column-major layout (note: the transpose is not necessary; if you check the latest implementation, you can just load the k slice in row-major layout and take an inner dot between the q tile and the k tile)
- you loaded V_block_ptr with shape (N_CTX_2, HIDDEN_DIM) in row-major layout
- however, you loaded Q_block_ptr with shape (N_CTX, HIDDEN_DIM)
Look at the CUDA blocks created:
grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)
You have created 1 * nheads * sequence_length / BLOCK_M blocks.
# Q is like a two-dimensional matrix with rows **(1 x nheads x sequence_length/BLOCK_M)** and columns **emb_d**, s.t. Q = [q1, q2, ...]_T
# K is like another two-dimensional matrix (we don't need to transpose it) with rows **(1 x nheads x sequence_length/BLOCK_N)** and columns **emb_d**, s.t. K = [k1, k2, ...]_T
To calculate causal attention we load q_i once, then per iteration:
- iter 0: k1, v1
- iter 1: k2, v2
- ...
- iter hi: k_hi, v_hi (BLOCK_N, emb_d)
Can you check the source code version you used and explain N_CTX_2 in your case?
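For reference, the tiled accumulation described above (load a q tile once, then stream K/V tiles while keeping a running max and running denominator) is the online-softmax recurrence used in the kernel's inner loop. A minimal plain-PyTorch sketch of that recurrence for a single query tile (a restatement for clarity, not code from the kernel or the PR; it uses exp instead of the kernel's exp2 trick):

import torch

def online_softmax_attention(q_i, K, V, sm_scale, BLOCK_N=64):
    # q_i: (BLOCK_M, d) query tile; K, V: (N, d).
    # m: running row-wise max, l: running softmax denominator, acc: rescaled accumulator.
    M, d = q_i.shape
    m = q_i.new_full((M,), float("-inf"), dtype=torch.float32)
    l = q_i.new_zeros(M, dtype=torch.float32)
    acc = q_i.new_zeros((M, d), dtype=torch.float32)
    for s in range(0, K.shape[0], BLOCK_N):
        k, v = K[s:s + BLOCK_N].float(), V[s:s + BLOCK_N].float()
        qk = (q_i.float() @ k.T) * sm_scale              # (M, <=BLOCK_N)
        m_new = torch.maximum(m, qk.max(dim=1).values)
        alpha = torch.exp(m - m_new)                     # rescale factor for the old state
        p = torch.exp(qk - m_new[:, None])
        acc = acc * alpha[:, None] + p @ v
        l = l * alpha + p.sum(dim=1)
        m = m_new
    return acc / l[:, None]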
Hi, I am using triton==2.1.0. N_CTX_2 means the sequence length of K/V, and N_CTX means the sequence length of Q.
I checked the tutorial https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html and https://github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py#L42 . Both use a transpose.
I also tried changing my second version above to:
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(N_CTX_2, HIDDEN_DIM),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
qk += tl.dot(q, tl.trans(k), allow_tf32=True)
Results:
max diff: 0.039794921875
max diff: 0.0565185546875
max diff: 0.06787109375
max diff: 0.04058837890625
@zhanglei1172 The code you referred to is the updated version with patch PR#2336. The definition of N_CTX_2 is here:
https://github.com/openai/triton/blob/72cba380aa64336c90cd98fc9c74cc5a5b205e05/python/triton/ops/flash_attention.py#L390
q.shape[0] * q.shape[1] * q.shape[2]
It is not the sequence_length of K or V. More precisely, it treats Q, K, V as two-dimensional matrices of shapes (-1, d), (d, -1), (d, -1), where the -1 axis runs over q.shape[0] * q.shape[1] * q.shape[2] positions along the sequence-length direction.
So I think what is used here is a mix of Flash Attention v2 (online softmax with memory-efficient attention, with Q loaded first in the outer loop) and the attention variant with Hopper TMA support.
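In code terms, my reading of that definition (a small sketch, assuming the standard (Z, H, N_CTX, d) layout of the upstream triton.ops.flash_attention):

import torch

# N_CTX_2 = q.shape[0] * q.shape[1] * q.shape[2] is the number of rows when
# Q is viewed as a flat (-1, d) matrix (sketch, not code from the PR).
Z, H, N_CTX, d = 2, 8, 1024, 64
q = torch.randn(Z, H, N_CTX, d)
n_ctx_2 = q.shape[0] * q.shape[1] * q.shape[2]
assert q.reshape(-1, d).shape == (n_ctx_2, d)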
@yiakwy-xpu-ml-framework-team You mentioned that N_CTX_2 (batches x hiddens x sequence_length) was introduced in the PR to support Hopper TMA, but the version of Triton I'm using (and referencing) does not contain that PR. Instead, I introduced N_CTX_2 myself (I did not refer to the code in the Hopper TMA PR) for scenarios dealing with cross-attention, where Q and K/V have different sequence lengths. The code I used is intended to differ from the original attention in two main ways: the permutation of the input shape, and the different sequence lengths of Q and K/V.
So ultimately I want to confirm whether the current Triton can support this kind of permutation and cross-attention, and if it can, how I can modify the original code to get correct results.
I tried the example from the 3.1.x branch. It has no N_CTX_2, but the answer is still wrong. I modified the test to use a sequence length of 3000, and the max diff against the math reference is as large as 1.7, whereas with 3072 tokens the max diff is less than 1e-2. I think the kernel is incorrect when the QKV sequence length is not a multiple of the block size.
You need to add boundary_check and padding_option arguments to the tl.load and tl.store instructions, and you also need to apply a mask to the padded part.
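A minimal sketch of that masking against the kernel above (an illustration of the suggestion, not a tested patch): after the qk tile is computed, force columns that fall beyond the real K/V length to -inf, so the zero-padded K columns loaded with padding_option="zero" cannot contribute to the softmax. The trailing query rows are already handled by the boundary-checked tl.store, which never writes rows with offs_m >= N_CTX.

# Inside the inner loop of _fwd_kernel2, after the dot product (sketch only):
qk += tl.dot(q, k, allow_tf32=True)
# mask out padded K/V columns when the K/V length is not a multiple of BLOCK_N
qk = tl.where((start_n + offs_n[None, :]) < N_CTX_2, qk, float("-inf"))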