[BUG] Triton compile costs too much time.
Test script:
import math
import time

import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt

torch.manual_seed(42)
def median(lst):
    sorted_lst = sorted(lst)
    n = len(sorted_lst)
    if n % 2 == 0:
        return (sorted_lst[n // 2 - 1] + sorted_lst[n // 2]) / 2
    else:
        return sorted_lst[n // 2]
def timer(options):
    def decorator(func):
        def wrapper(*args, **kwargs):
            clone_list = options.get("clone_list", [])
            inplace_list = options.get("inplace_list", [])
            warm_cnt = options.get("warm_cnt", 20)
            run_cnt = options.get("run_cnt", 50)
            kwargs_p = kwargs.copy()
            clone_args = [arg for arg in args]
            for clone_idx in clone_list:
                clone_args[clone_idx] = args[clone_idx].clone()
            # Warm-up: the first iteration also pays the Triton compile cost.
            for i in range(warm_cnt):
                ta = time.perf_counter()
                results = func(*clone_args, **kwargs_p)
                torch.cuda.synchronize()
                tb = time.perf_counter()
                print(i, (tb - ta) * 1000, "ms")
            # Timed runs: report the median wall-clock time in ms.
            t0 = time.perf_counter()
            elapses = []
            for i in range(run_cnt):
                ts = time.perf_counter()
                if i == run_cnt - 1 and len(clone_list) > 0:
                    # Re-clone on the last run so the returned outputs are fresh.
                    for clone_idx in clone_list:
                        clone_args[clone_idx] = args[clone_idx].clone()
                results = func(*clone_args, **kwargs_p)
                torch.cuda.synchronize()
                te = time.perf_counter()
                elapses.append(te - ts)
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            tt = median(elapses) * 1000
            if len(inplace_list) == 1:
                results = [clone_args[inplace_list[0]]]
            elif len(inplace_list) > 1:
                results = [clone_args[idx] for idx in inplace_list]
            elif not isinstance(results, (list, tuple)):
                results = [results]
            return tt, results
        return wrapper
    return decorator
warm_cnt = 10
run_cnt = 10
def next_power_of_2(n, m=1):
    if n < m:
        return m
    return 2 ** math.ceil(math.log2(n))
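# Quick examples added for illustration:
#   next_power_of_2(290) == 512
#   next_power_of_2(5, 16) == 16   (m acts as a lower bound on the result)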
@triton.jit
def dot_test_1(a, b, c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
               M2: tl.constexpr, N2: tl.constexpr, K2: tl.constexpr):
    M_offs = tl.arange(0, M2)
    N_offs = tl.arange(0, N2)
    K_offs = tl.arange(0, K2)
    # (M, K) pointers into a
    a_offs = a + M_offs[:, None] * K + K_offs[None, :]
    # (N, K) pointers into b (i.e. b transposed)
    b_offs = b + K_offs[None, :] * N + N_offs[:, None]
    # (M, N) pointers into c
    c_offs = c + M_offs[:, None] * N + N_offs[None, :]
    # (M, K)
    a_vals = tl.load(a_offs, mask=(M_offs[:, None] < M) & (K_offs[None, :] < K), other=0).to(tl.float32)
    # (N, K); other=0 so masked lanes cannot pollute the sum
    b_vals = tl.load(b_offs, mask=(K_offs[None, :] < K) & (N_offs[:, None] < N), other=0).to(tl.float32)
    # (M, 1, K) * (1, N, K) -> (M, N, K)
    c_vals = a_vals[:, None, :] * b_vals[None, :, :]
    # (M, N)
    c_vals = tl.sum(c_vals, axis=2).to(tl.float16)
    tl.store(c_offs, c_vals, mask=(M_offs[:, None] < M) & (N_offs[None, :] < N))
    return
@timer({"warm_cnt": warm_cnt, "run_cnt": run_cnt})
def dot_test_1_wrapper(a, b, c, M, N, K):
    return dot_test_1[(1,)](a, b, c, M, N, K,
                            next_power_of_2(M), next_power_of_2(N), next_power_of_2(K),
                            num_stages=2, num_warps=2)
@triton.jit
def dot_test_2(a, b, c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
               BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    M_offs = tl.arange(0, BM)
    N_offs = tl.arange(0, BN)
    K_offs = tl.arange(0, BK)
    Ns = tl.cdiv(N, BN)
    Ks = tl.cdiv(K, BK)
    for Ni in range(Ns):
        c_vals = tl.zeros((BM, BN), dtype=tl.float32)
        for Ki in range(Ks):
            # (M, K) tile of a
            ak_offs = Ki * BK + K_offs[None, :]
            a_offs = a + M_offs[:, None] * K + ak_offs
            a_vals = tl.load(a_offs, mask=(M_offs[:, None] < M) & (ak_offs < K), other=0).to(tl.float32)
            # (K, N) tile of b
            bn_offs = Ni * BN + N_offs[None, :]
            bk_offs = Ki * BK + K_offs[:, None]
            b_vals = tl.load(b + bk_offs * N + bn_offs, mask=(bk_offs < K) & (bn_offs < N), other=0).to(tl.float32)
            c_vals += tl.dot(a_vals, b_vals)
        # (M, N)
        cn_offs = Ni * BN + N_offs[None, :]
        tl.store(c + M_offs[:, None] * N + cn_offs, c_vals, mask=(M_offs[:, None] < M) & (cn_offs < N))
    return
@timer({"warm_cnt": warm_cnt, "run_cnt": run_cnt})
def dot_test_2_wrapper(a, b, c, M, N, K):
    return dot_test_2[(1,)](a, b, c, M, N, K, 16, 16, next_power_of_2(K, 16),
                            num_stages=2, num_warps=2)
'''
Motivating attention shapes:
query (kv_head_groups, head_dim) @ (head_dim, group_size)
key   (head_dim, group_size)
value (kv_head_groups, group_size) @ (group_size, head_dim)
'''
from tqdm import tqdm

fig, axs = plt.subplots(5, 2, figsize=(20, 24))
Ms = [1, 2, 4, 8, 16]
for i in range(len(Ms)):
    M = Ms[i]
    for case in [0, 1]:
        ax = axs[i, case]
        tts = []
        tss = []
        for group_size in tqdm(range(1, 513)):
            if case == 0:
                N = 128
                K = group_size
            else:
                N = group_size
                K = 128
            a = torch.randn((M, K), dtype=torch.float16).cuda()
            b = torch.randn((K, N), dtype=torch.float16).cuda()
            c1 = torch.zeros((M, N), dtype=torch.float16).cuda()
            c2 = torch.zeros((M, N), dtype=torch.float16).cuda()
            tt, _ = dot_test_1_wrapper(a, b, c1, M, N, K)
            ts, _ = dot_test_2_wrapper(a, b, c2, M, N, K)
            exp = (a.to(torch.float32) @ b.to(torch.float32)).to(torch.float16)
            print(torch.allclose(c1, exp, atol=1e-3, rtol=1e-3),
                  torch.allclose(c2, exp, atol=1e-3, rtol=1e-3))
            tts.append(tt)
            tss.append(ts)
            print(f"({M}, {K}) @ ({K}, {N})", tt, ts)
        index = list(range(1, 512 + 1))
        ax.plot(index, tts, label="tl.sum", color='blue')
        ax.plot(index, tss, label="tl.dot", color='red')
        ax.set_title(f'({M}, {K}) @ ({K}, {N}) tl.sum vs tl.dot')
        ax.set_xlabel("group_size")
        ax.set_ylabel("time (ms)")
        ax.legend()
plt.tight_layout()
plt.savefig('test.jpg', format='jpg')
print("save plot to file test.jpg")
On the first run with M = 2, N = 128, K = 290, dot_test_1 takes 22787.6 ms and dot_test_2 takes 1065.8 ms; almost all of that first call is Triton compilation. The steady-state running times are 0.1048 ms vs 0.1748 ms, and that speed advantage is why I tried the alternative dot_test_1 implementation in the first place.
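To confirm that the first-call cost is compilation rather than kernel execution, here is a minimal sketch (my addition; the first_vs_steady helper is hypothetical and not part of the test script) that times the first launch separately from later launches, which hit the in-process kernel cache:

def first_vs_steady(launch, iters=10):
    # First call: includes Triton JIT compilation for this specialization.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    launch()
    torch.cuda.synchronize()
    first_ms = (time.perf_counter() - t0) * 1000
    # Later calls: the compiled kernel is fetched from the cache.
    t0 = time.perf_counter()
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    steady_ms = (time.perf_counter() - t0) * 1000 / iters
    return first_ms, steady_ms

Calling it as first_vs_steady(lambda: dot_test_1[(1,)](a, b, c1, M, N, K, next_power_of_2(M), next_power_of_2(N), next_power_of_2(K))) should show the ~22.8 s first call followed by ~0.1 ms launches.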
Testing device: NVIDIA A10.
The warm-up timings print like this (first block: dot_test_1, second block: dot_test_2):
0 22787.686996161938 ms
1 0.1951158046722412 ms
2 0.122894998639822 ms
3 0.11178990826010704 ms
4 0.11019594967365265 ms
5 0.10815728455781937 ms
6 0.10741734877228737 ms
7 0.10830676183104515 ms
8 0.10813027620315552 ms
9 0.10803062468767166 ms
0 1076.9212269224226 ms
1 0.3412882797420025 ms
2 0.21489430218935013 ms
3 0.19933003932237625 ms
4 0.19005779176950455 ms
5 0.18493272364139557 ms
6 0.18532201647758484 ms
7 0.183924101293087 ms
8 0.18175877630710602 ms
9 0.192372128367424 ms
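As a cross-check of the steady-state numbers, the same kernels can be timed with Triton's bundled triton.testing.do_bench. A sketch, assuming the installed release exposes the usual do_bench(fn) interface (it performs its own warm-up, so the one-time compile cost is excluded; whether it reports a mean or a median varies across Triton versions):

import triton.testing

# Steady-state launch time in ms; compilation happens during do_bench's warm-up.
ms = triton.testing.do_bench(
    lambda: dot_test_2[(1,)](a, b, c2, M, N, K, 16, 16,
                             next_power_of_2(K, 16),
                             num_stages=2, num_warps=2))
print(f"dot_test_2 steady state: {ms:.4f} ms")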