[BUG Report] Expected valid results from a padded tl.dot, but the stored output is wrong
I call tl.dot in a single block with M = 16, N = 1, K = 128. Since tl.dot requires M, N, K >= 16, I pad N up to PN = 16.
A has shape (M, K) = (16, 128).
B has shape (N, K) = (1, 128); the padded PB has shape (PN, K) = (16, 128).
C has shape (M, N) = (16, 1); the padded PC has shape (M, PN) = (16, 16).
Calling PC = tl.dot(A, tl.trans(PB)) yields an (M, PN) matrix in which PC[:, 1:] should be all zeros (the padded rows of PB are zero) and PC[:, :1] should be the correct answer, even when PC is dumped without a mask.
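As a quick sanity check of that padding argument, here is a minimal plain-PyTorch sketch (no Triton; names mirror the prose above and are illustrative) confirming that zero-padding B's rows leaves PC[:, :N] equal to the unpadded product and PC[:, N:] exactly zero:

import torch

M, N, K, PN = 16, 1, 128, 16
A = torch.randn(M, K)
B = torch.randn(N, K)
PB = torch.zeros(PN, K)
PB[:N] = B  # rows N..PN-1 stay zero

PC = A @ PB.T  # (M, PN)
assert torch.allclose(PC[:, :N], A @ B.T)  # valid columns match the unpadded product
assert torch.all(PC[:, N:] == 0)  # padded columns are exactly zero
print("padding check passed")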
So let's try it:
import torch
import triton
import triton.language as tl
torch.manual_seed(42)
M = 16
K = 128
N = 1
PN = 16
a = 0.5 * torch.randn((M, K), dtype=torch.float16).cuda()
b = 0.5 * torch.randn((N, K), dtype=torch.float16).cuda()
c_tmp = torch.zeros((M, 1), dtype=torch.float16).cuda().float()  # (unused below)
c_cp = torch.zeros((M, PN), dtype=torch.float16).cuda()  # full (M, PN) output dumped by the kernel
pad_b = torch.zeros((PN, K), dtype=torch.float16).cuda()
pad_b[:N, :] = b  # zero-pad B's rows up to PN
c_pad = a.float() @ pad_b.T.float()  # float32 reference result, shape (M, PN)
a_cp = torch.ones((M, K), dtype=torch.float16).cuda()  # debug dump of the loaded A (ones, so unwritten entries stand out)
b_cp = torch.ones((K, PN), dtype=torch.float16).cuda()  # debug dump of the loaded, transposed B
@triton.jit
def dot_test(a, b, c, a_cp, b_cp, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, PN: tl.constexpr):
'''
a: (M, K) = (16, 128)
b: (N, K) = (1, 128), loaded transposed and zero-padded to (K, PN) = (128, 16)
PN = 16
'''
M_offs = tl.arange(0, M) # 0-15
K_offs = tl.arange(0, K) # 0-127
N_offs = tl.arange(0, PN) # 0-15
a_offs = a + M_offs[:, None] * K + K_offs[None, :]
# (M, K) = 16x128 load; this mask (K_offs < K) is always true
a_vals = tl.load(a_offs, mask=K_offs[None, :] < K, other=0)
b_offs = b + N_offs[None, :] * K + K_offs[:, None]
tl.store(a_cp + M_offs[:, None] * K + K_offs[None, :], a_vals)  # dump loaded A for host-side inspection
# (K, PN) = 128x16 masked load of B, pre-transposed: column 0 holds valid data, columns 1: are masked to zero
b_vals = tl.load(b_offs, mask=N_offs[None, :] < N, other=0)
tl.store(b_cp + K_offs[:, None] * PN + N_offs[None, :], b_vals)  # dump loaded (padded) B for host-side inspection
# (M, PN) = 16x16 result: c_vals[:, :1] holds valid data, c_vals[:, 1:] should be 0
c_vals = tl.dot(a_vals, b_vals)
c_offs = c + M_offs[:, None] * N + N_offs[None, :]  # offsets into c, built with row stride N
# store the full (M, PN) c_vals without a mask; the partial slice c_vals[:, :1] should still hold valid numbers?
tl.store(c_offs, c_vals)
# tl.store(c_offs, c_vals, mask=N_offs[None, :] < N)
dot_test[(1,)](a, b, c_cp, a_cp, b_cp, M, N, K, PN, num_stages=1, num_warps=2)
print("triton c out", c_cp[:, 0])
print("expect c out", c_pad[:, 0])
print("origin b", b.view(-1))
print("triton b", b_cp[:, 0])
print(torch.all(b_cp[:, 1:] == 0))
The code above produces the following console output:
triton c out tensor([-3.0566, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
device='cuda:0', dtype=torch.float16)
expect c out tensor([-3.0574, 3.4058, -2.0936, 2.2738, 0.1777, 1.2245, 0.8552, 1.5280,
-5.8642, -0.4714, -4.4082, 0.8378, 1.5740, -0.0267, -0.8717, 0.0216],
device='cuda:0')
Why is the full c stored by Triton not correct?
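For reference, the masked-store variant hinted at by the commented-out line writes only the valid column into a true (M, N) buffer. Below is a minimal, self-contained sketch of that pattern; the kernel name and tolerances are my own, and it assumes the goal is only the (M, N) slice of the dot:

import torch
import triton
import triton.language as tl

@triton.jit
def dot_masked_store(a, b, c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, PN: tl.constexpr):
    m = tl.arange(0, M)
    k = tl.arange(0, K)
    n = tl.arange(0, PN)
    a_vals = tl.load(a + m[:, None] * K + k[None, :])
    # load B pre-transposed to (K, PN), masking the padded columns to zero
    b_vals = tl.load(b + n[None, :] * K + k[:, None], mask=n[None, :] < N, other=0)
    c_vals = tl.dot(a_vals, b_vals)  # (M, PN); columns N: are zero
    # store only the valid columns into a true (M, N) buffer with row stride N
    tl.store(c + m[:, None] * N + n[None, :], c_vals.to(tl.float16), mask=n[None, :] < N)

M, N, K, PN = 16, 1, 128, 16
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(N, K, dtype=torch.float16, device="cuda")
c = torch.empty(M, N, dtype=torch.float16, device="cuda")
dot_masked_store[(1,)](a, b, c, M, N, K, PN)
print(torch.allclose(c.float(), a.float() @ b.float().T, rtol=1e-2, atol=1e-2))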