
[BUG] Triton compile takes too much time.

Open · MeJerry215 opened this issue on Apr 24, 2024 · 0 comments

Test script

import math
import time

import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(42)


def median(lst):
    sorted_lst = sorted(lst)
    n = len(sorted_lst)
    if n % 2 == 0:
        return (sorted_lst[n//2 - 1] + sorted_lst[n//2]) / 2
    else:
        return sorted_lst[n//2]


def timer(options):
    """Decorator: warm up `func`, then run it `run_cnt` times and return the median time in ms."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            clone_list = options.get("clone_list", [])      # arg indices to clone so runs see fresh inputs
            inplace_list = options.get("inplace_list", [])   # arg indices returned as results (in-place ops)
            warm_cnt = options.get("warm_cnt", 20)
            run_cnt = options.get("run_cnt", 50)
            kwargs_p = kwargs.copy()
            clone_args = list(args)

            for clone_idx in clone_list:
                clone_args[clone_idx] = args[clone_idx].clone()
            # warm-up: the first iteration includes Triton JIT compilation
            for i in range(warm_cnt):
                ta = time.perf_counter()
                results = func(*clone_args, **kwargs_p)
                torch.cuda.synchronize()
                tb = time.perf_counter()
                print(i, (tb - ta) * 1000, "ms")
            t0 = time.perf_counter()
            elapses = []
            for i in range(run_cnt):
                ts = time.perf_counter()
                if i == run_cnt - 1 and len(clone_list) > 0:
                    for clone_idx in clone_list:
                        clone_args[clone_idx] = args[clone_idx].clone()
                results = func(*clone_args, **kwargs_p)
                torch.cuda.synchronize()
                te = time.perf_counter()
                elapse = te - ts
                elapses.append(elapse)

            torch.cuda.synchronize()
            t1 = time.perf_counter()
            tt = median(elapses) * 1000
            if len(inplace_list) == 1:
                results = [clone_args[inplace_list[0]], ]
            elif len(inplace_list) > 1:
                results = [clone_args[idx] for idx in inplace_list]
            else:
                if not isinstance(results, (list, tuple)):
                    results = [results]
            return tt, results
        return wrapper
    return decorator

warm_cnt = 10
run_cnt = 10

def next_power_of_2(n, m=1):
    if n < m:
        return m
    return 2 ** math.ceil(math.log2(n))
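
# Note: in the kernels below, M, N, K and the padded sizes (M2/N2/K2, BK) are all
# tl.constexpr, so every distinct group_size value produces a new specialization
# and therefore a fresh Triton compile when sweeping group_size from 1 to 512.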

@triton.jit
def dot_test_1(a, b, c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
               M2: tl.constexpr, N2: tl.constexpr, K2: tl.constexpr):
    M_offs = tl.arange(0, M2)
    N_offs = tl.arange(0, N2)
    K_offs = tl.arange(0, K2)
    # M, K
    a_offs = a + M_offs[:, None] * K + K_offs[None, :]
    # N, K
    b_offs = b + K_offs[None, :] * N + N_offs[:, None]
    # M, N
    c_offs = c + M_offs[:, None] * N + N_offs[None, :]
    # M, K
    a_vals = tl.load(a_offs, mask=(M_offs[:, None] < M) & (K_offs[None, :] < K), other=0).to(tl.float32)
    # N, K
    b_vals = tl.load(b_offs, mask=(K_offs[None, :] < K) & (N_offs[:, None] < N), other=0).to(tl.float32)
    # M, 1, K + 1, N, K -> M, N, K
    c_vals = a_vals[:, None, :] * b_vals[None, :, :]
    # M, N
    c_vals = tl.sum(c_vals, axis=2).to(tl.float16)
    tl.store(c_offs, c_vals, mask=(M_offs[:, None] < M) & (N_offs[None, :] < N))
    return

@timer({"warm_cnt": warm_cnt, "run_cnt": run_cnt})
def dot_test_1_wrapper(a, b, c, M, N, K):
    return dot_test_1[(1,)](a, b, c, M, N, K, next_power_of_2(M), next_power_of_2(N), next_power_of_2(K), num_stages=2, num_warps=2)

@triton.jit
def dot_test_2(a, b, c, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    M_offs = tl.arange(0, BM)
    N_offs = tl.arange(0, BN)
    K_offs = tl.arange(0, BK)

    Ns = tl.cdiv(N, BN)
    Ks = tl.cdiv(K, BK)
    for Ni in range(Ns):
        c_vals = tl.zeros((BM, BN), dtype=tl.float32)
        for Ki in range(Ks):
            # M, K
            ak_offs = Ki * BK + K_offs[None, :]
            a_offs = a + M_offs[:, None] * K + ak_offs
            a_vals = tl.load(a_offs, mask=(M_offs[:, None] < M) & (ak_offs < K), other=0).to(tl.float32)
            # K, N
            bn_offs = Ni * BN + N_offs[None, :]
            bk_offs = (Ki * BK + K_offs[:, None])
            b_vals = tl.load(b + bk_offs * N + bn_offs, mask=(bk_offs < K) & (bn_offs < N), other=0).to(tl.float32)
            c_vals += tl.dot(a_vals, b_vals)
            # M, N
        cn_offs = Ni * BN + N_offs[None, :]
        tl.store(c + M_offs[:, None] * N + cn_offs, c_vals, mask=(M_offs[:, None] < M) & (cn_offs < N))
    return


@timer({"warm_cnt": warm_cnt, "run_cnt": run_cnt})
def dot_test_2_wrapper(a, b, c, M, N, K):
    return dot_test_2[(1,)](a, b, c, M, N, K, 16, 16, next_power_of_2(K, 16), num_stages=2, num_warps=2)

# Motivating shapes (attention with grouped KV heads):
#   query:  (kv_head_groups, head_dim) @ (head_dim, group_size)
#   key:    (head_dim, group_size)
#   value:  (kv_head_groups, group_size) @ (group_size, head_dim)

fig, axs = plt.subplots(5, 2, figsize=(20, 24))

Ms = [1, 2, 4, 8, 16]
for i in range(len(Ms)):
    M = Ms[i]
    for case in [0, 1]:
        ax = axs[i, case]
        tts = []
        tss = []
        for group_size in tqdm(range(1, 513)):
            if case == 0:
                N = 128
                K = group_size
            else:
                N = group_size
                K = 128
            a = torch.randn((M, K), dtype=torch.float16).cuda()
            b = torch.randn((K, N), dtype=torch.float16).cuda()
            c1 = torch.zeros((M, N), dtype=torch.float16).cuda()
            c2 = torch.zeros((M, N), dtype=torch.float16).cuda()
            tt, _ = dot_test_1_wrapper(a, b, c1, M, N, K,)
            ts, _ = dot_test_2_wrapper(a, b, c2, M, N, K,)
            exp = (a.to(torch.float32) @ b.to(torch.float32)).to(torch.float16)
            # import pdb
            # pdb.set_trace()
            print(torch.allclose(c1, exp, atol=1e-3, rtol=1e-3), torch.allclose(c2, exp, atol=1e-3, rtol=1e-3))
            tts.append(tt)
            tss.append(ts)
            print(f"({M}, {K}) @ ({K}, {N})", tt, ts)
        # plot this (M, case) panel; for case 0 the swept dimension is K, for case 1 it is N
        index = list(range(1, 512 + 1))
        ax.plot(index, tts, label="tl.sum", color='blue')
        ax.plot(index, tss, label="tl.dot", color='red')
        if case == 0:
            ax.set_title(f'({M}, group_size) @ (group_size, {N}): tl.sum vs tl.dot')
        else:
            ax.set_title(f'({M}, {K}) @ ({K}, group_size): tl.sum vs tl.dot')
        ax.set_xlabel("group_size")
        ax.set_ylabel('time(ms)')
        ax.legend()
plt.tight_layout()
plt.savefig('test.jpg', format='jpg')
print("saved plot to test.jpg")



# dot_test_2[[1, ]](a, b, c, M, N, K, next_power_of_2(M, 16), next_power_of_2(N, 16), next_power_of_2(K, 16))
# dot_test_1[[1, ]](a, b, c, M, N, K, next_power_of_2(M), next_power_of_2(N), next_power_of_2(K))
# cq = (a.view(M, 1, K) * b.T.view(1, N, K)).sum(dim=-1)
# print(torch.allclose(c, exp, atol=1e-3, rtol=1e-3))
# import pdb
# pdb.set_trace()

It seems that on the first run with M = 2, N = 128, K = 290, dot_test_1 takes 22787.6 ms and dot_test_2 takes 1065.8 ms.
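
To double-check that the one-off first-launch cost is the JIT compile rather than the launch itself, a minimal sketch along these lines can be used (it assumes the dot_test_1 kernel and next_power_of_2 helper defined in the script above; Triton also keeps an on-disk cache, by default under ~/.triton/cache, so clear it or point TRITON_CACHE_DIR at an empty directory for a truly cold first launch):

import time
import torch

def time_first_and_second_launch(launch):
    # first call: Triton JIT-compiles this specialization, then runs it
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    launch()
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    # second call: the compiled kernel is cached in-process, so this is run time only
    launch()
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    return (t1 - t0) * 1e3, (t2 - t1) * 1e3

M, N, K = 2, 128, 290
a = torch.randn((M, K), dtype=torch.float16, device="cuda")
b = torch.randn((K, N), dtype=torch.float16, device="cuda")
c = torch.zeros((M, N), dtype=torch.float16, device="cuda")

first_ms, second_ms = time_first_and_second_launch(
    lambda: dot_test_1[(1,)](a, b, c, M, N, K,
                             next_power_of_2(M), next_power_of_2(N), next_power_of_2(K),
                             num_stages=2, num_warps=2))
print(f"dot_test_1: first launch {first_ms:.1f} ms, second launch {second_ms:.4f} ms")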

After warm-up, the steady-state running times are 0.1048 ms vs 0.1748 ms, which is why I tried the alternative implementation dot_test_1.
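
As a rough illustration of how differently the two kernels are structured for the M = 2, N = 128, K = 290 case above (just element-count arithmetic, not a claim about where the compile time is actually spent): dot_test_1 materializes the full (M2, N2, K2) broadcast product before tl.sum, while dot_test_2 only ever holds two tiles and an accumulator.

M2, N2, K2 = 2, 128, 512                         # next_power_of_2 of M=2, N=128, K=290
sum_variant_elems = M2 * N2 * K2                 # 131072 fp32 values per program for dot_test_1
BM, BN, BK = 16, 16, 512                         # block sizes used by dot_test_2_wrapper
dot_variant_elems = BM * BK + BK * BN + BM * BN  # 16640 fp32 values per program for dot_test_2
print(sum_variant_elems, dot_variant_elems)      # 131072 vs 16640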

Testing device: NVIDIA A10

The printed running times look like the following:

0 22787.686996161938 ms
1 0.1951158046722412 ms
2 0.122894998639822 ms
3 0.11178990826010704 ms
4 0.11019594967365265 ms
5 0.10815728455781937 ms
6 0.10741734877228737 ms
7 0.10830676183104515 ms
8 0.10813027620315552 ms
9 0.10803062468767166 ms
0 1076.9212269224226 ms
1 0.3412882797420025 ms
2 0.21489430218935013 ms
3 0.19933003932237625 ms
4 0.19005779176950455 ms
5 0.18493272364139557 ms
6 0.18532201647758484 ms
7 0.183924101293087 ms
8 0.18175877630710602 ms
9 0.192372128367424 ms

MeJerry215 · Apr 24, 2024