
[Bug Report] Expected full padded tl.dot results but got wrong results.


I only call tl.dot in a single block with M = 16, N = 1, K = 128. Since tl.dot requires M, N, K >= 16, I pad N up to PN = 16.
A has shape (M, K) = (16, 128), B has shape (N, K) = (1, 128), and the padded PB has shape (PN, K) = (16, 128). C has shape (M, N) = (16, 1), and the padded PC has shape (M, PN) = (16, 16).

If I call PC = tl.dot(A, tl.trans(PB)), I should get an (M, PN) matrix in which PC[:, 1:] is all zeros and PC[:, :1] is the correct answer, even when dumping PC without a mask.
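As a plain-PyTorch sketch of that claim (illustrative only, not part of the Triton repro, but using the same shapes): the zero rows of PB can only produce zero columns in PC, so PC[:, N:] must be 0 and PC[:, :N] must equal the unpadded product.

import torch

M, N, K, PN = 16, 1, 128, 16
A = torch.randn(M, K)
B = torch.randn(N, K)
PB = torch.zeros(PN, K)
PB[:N, :] = B              # pad B with zero rows up to PN

PC = A @ PB.T              # (M, PN) padded result
C = A @ B.T                # (M, N) unpadded reference

assert torch.allclose(PC[:, :N], C)   # valid column matches the reference
assert torch.all(PC[:, N:] == 0)      # padded columns are exactly zero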

So let's try it.


import torch
import triton
import triton.language as tl

torch.manual_seed(42)
M = 16
K = 128
N = 1
PN = 16
a = 0.5 * torch.randn((M, K), dtype=torch.float16).cuda()
b = 0.5 * torch.randn((N, K), dtype=torch.float16).cuda()
c_tmp = torch.zeros((M, 1), dtype=torch.float16).cuda().float()
c_cp = torch.zeros((M, PN), dtype=torch.float16).cuda()
pad_b = torch.zeros((PN, K), dtype=torch.float16).cuda()
pad_b[:N, :] = b
c_pad = a.float() @ pad_b.T.float()

a_cp = torch.ones((M, K), dtype=torch.float16).cuda()
b_cp = torch.ones((K, PN), dtype=torch.float16).cuda()


@triton.jit
def dot_test(a, b, c, a_cp, b_cp, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, PN: tl.constexpr):
    '''
    a M, K   16x128
    b N, K   1 x 128
    PN = 16
    '''
    M_offs = tl.arange(0, M)    # 0-15
    K_offs = tl.arange(0, K)    # 0-127
    N_offs = tl.arange(0, PN)   # 0-15
    a_offs = a + M_offs[:, None] * K + K_offs[None, :]
    # 16x128 load with mask
    a_vals = tl.load(a_offs, mask=K_offs[None, :] < K, other=0)
    b_offs = b + N_offs[None, :] * K + K_offs[:, None]
    tl.store(a_cp + M_offs[:, None] * K + K_offs[None, :], a_vals)
    # 128x16 masked load: [:, :1] is valid data, [:, 1:] is masked to zero; this load also performs the transpose
    b_vals = tl.load(b_offs, mask=N_offs[None, :] < N, other=0)
    tl.store(b_cp + K_offs[:, None] * PN + N_offs[None, :], b_vals)
    # 16x16 result: c_vals[:, :1] is valid data, c_vals[:, 1:] should be 0
    c_vals = tl.dot(a_vals, b_vals)
    # c_vals, tl.max(tl.where(N_offs < N, c_vals, min_val), axis=1)
    c_offs = c + M_offs[:, None] * N + N_offs[None, :]
    # store the full 16x16 c_vals; the partial results in c_vals[:, :1] should be valid numbers?
    tl.store(c_offs, c_vals)
    # tl.store(c_offs, c_vals, mask=N_offs[None, :] < N)


dot_test[(1, )](a, b, c_cp, a_cp, b_cp, M, N, K, PN, num_stages=1, num_warps=2)
print("triton c out", c_cp[:, 0])
print("expect c out", c_pad[:, 0])
print("origin b", b.view(-1))
print("triton b", b_cp[:, 0])
print(torch.all(b_cp[:, 1:] == 0))

The above code generates the following console log:

triton c out tensor([-3.0566,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
       device='cuda:0', dtype=torch.float16)
expect c out tensor([-3.0574,  3.4058, -2.0936,  2.2738,  0.1777,  1.2245,  0.8552,  1.5280,
        -5.8642, -0.4714, -4.4082,  0.8378,  1.5740, -0.0267, -0.8717,  0.0216],
       device='cuda:0')

Why is the full C stored by Triton not correct?

MeJerry215 · Apr 30 '24 07:04