[Question] How to improve efficiency for TileLang kernel computing torch.einsum('thd,kd->thk', query, key).sum(dim=1)
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker to confirm this hasn't already been reported (comment there if it has).
Questions
Hi TileLang team,
I'm a new user of TileLang and trying to implement a kernel that computes:
torch.einsum('thd,kd->thk', query, key).sum(dim=1)
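For reference on shapes: query is (t, h, d) and key is (t, d); the einsum yields (t, h, t) and the sum over dim=1 gives a (t, t) output. Note also that the sum over h commutes with the contraction over d, so the result equals (up to floating-point rounding) a single matmul on the head-summed query:

# Mathematically equivalent form (pure PyTorch, for reference only):
#   out[i, j] = sum_d (sum_h query[i, h, d]) * key[j, d]
out_equiv = query.sum(dim=1) @ key.T  # (t, d) @ (d, t) -> (t, t)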
Here’s my current kernel code:
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(
    out_idx=[-1],
)
def kernel_impl(
    t: int,
    h: int,
    d: int,
    block_size: int,
    BLOCK_D: int,
    dtype: str = "bfloat16",
    accum_dtype: str = "float",
):
    # Each thread handles TILE_D contiguous elements of a BLOCK_D-wide chunk.
    TILE_D = T.ceildiv(BLOCK_D, block_size)

    @T.prim_func
    def main(
        Q: T.Tensor((t, h, d), dtype),
        K: T.Tensor((t, d), dtype),
        S: T.Tensor((t, t), dtype),
    ):
        with T.Kernel(t, threads=block_size) as bx:
            tk = T.get_thread_binding(0)
            K_local = T.alloc_local((TILE_D,), dtype)
            Q_local = T.alloc_local((TILE_D,), dtype)
            S_accum = T.alloc_local((1,), accum_dtype)
            for by in T.Parallel(t):
                T.clear(S_accum)
                # Per-thread partial sum of sum_h sum_d Q[bx, bh, :] * K[by, :].
                for bh in T.serial(h):
                    for bk in T.serial(T.ceildiv(d, BLOCK_D)):
                        for k in T.vectorized(TILE_D):
                            d_idx = bk * BLOCK_D + tk * TILE_D + k
                            K_local[k] = K[by, d_idx]
                            Q_local[k] = Q[bx, bh, d_idx]
                        for k in T.serial(TILE_D):
                            S_accum[0] += K_local[k].astype(accum_dtype) * Q_local[k].astype(accum_dtype)
                # Cross-thread reduction of the per-thread partial sums.
                S_reduced = T.alloc_local((1,), accum_dtype)
                with T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                ):
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            S_accum[0],
                            True,
                            S_reduced[0],
                            tk,
                            dtype="handle",
                        ))
                S[bx, by] = S_reduced[0]

    return main


if __name__ == "__main__":
    t = 1024 * 1
    h = 16
    d = 128
    device = "cuda"
    print(f"Testing kernel with t={t}, h={h}, d={d}")

    query = torch.randn(t, h, d, dtype=torch.bfloat16, device=device)
    key = torch.randn(t, d, dtype=torch.bfloat16, device=device)
    kernel = kernel_impl(t, h, d, block_size=16, BLOCK_D=128)
    logits = kernel(query, key)

    # reference
    rf_logits_tht = torch.einsum('thd,kd->thk', query, key)
    rf_logits = rf_logits_tht.sum(dim=1)
    print(f"logits: {logits[0][-5:]}")
    print(f"reference logits: {rf_logits[0][-5:]}")

    # check diff
    diff = torch.abs(logits - rf_logits)
    print(f"max diff: {diff.max().item()}")
    print(f"mean diff: {diff.mean().item()}")

    from tilelang.profiler import do_bench

    def run_optimized():
        return kernel(query, key)

    optimized_ms = do_bench(run_optimized, rep=100, warmup=10)
    print(f"Tile-lang Kernel time: {optimized_ms:.3f} ms")

    def run_reference():
        return torch.einsum('thd,kd->thk', query, key).sum(dim=1)

    reference_ms = do_bench(run_reference, rep=100, warmup=10)
    print(f"Reference time: {reference_ms:.3f} ms")
    print(f"Speedup: {reference_ms / optimized_ms:.2f}x")
Benchmark results (bfloat16 input, float accumulation):
max diff: 1.0
mean diff: 0.05908203125
Tile-lang Kernel time: 0.705 ms
Reference time: 0.034 ms
Speedup: 0.05x
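(Side check, not part of the original run: to see how much of that max diff is just bfloat16 rounding rather than a logic error, the reference can be recomputed in float32; the variable names below match the script above.)

# Recompute the reference in float32 to separate bf16 rounding from kernel bugs.
rf_fp32 = torch.einsum('thd,kd->thk', query.float(), key.float()).sum(dim=1)
print("bf16 kernel vs fp32 reference:", (logits.float() - rf_fp32).abs().max().item())
print("bf16 einsum vs fp32 reference:", (rf_logits.float() - rf_fp32).abs().max().item())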
It seems my kernel has:
- Large numerical deviation from the PyTorch reference
- Much slower runtime (0.705 ms vs 0.034 ms)
How can I optimize this so that the TileLang kernel beats PyTorch? Thanks a lot for your guidance!
Update: here is a second attempt that stages tiles through shared memory and uses T.gemm:
import torch
import tilelang
import tilelang.language as T


@tilelang.jit(
    out_idx=[-1],
)
def kernel_impl(
    t: int,
    h: int,
    d: int,
    threads: int,
    dtype: str = "bfloat16",
    accum_dtype: str = "float",
):
    num_stages = 3
    BI = 64  # key rows processed per pipelined iteration
    NI = tilelang.cdiv(t, BI)

    @T.prim_func
    def main(
        Q: T.Tensor((t, h, d), dtype),
        K: T.Tensor((t, d), dtype),
        S: T.Tensor((t, t), accum_dtype),
    ):
        with T.Kernel(t, threads=threads) as q_idx:
            s_i = q_idx
            Q_shared = T.alloc_shared([h, d], dtype)
            K_shared = T.alloc_shared([BI, d], dtype)
            acc_s = T.alloc_fragment([h, BI], accum_dtype)
            acc_s_sum = T.alloc_fragment([BI], accum_dtype)
            T.copy(Q[s_i, :, :], Q_shared)
            for i_i in T.Pipelined(NI, num_stages=num_stages):
                # Stage the next BI key rows into shared memory.
                for b_i, d_i in T.Parallel(BI, d):
                    K_shared[b_i, d_i] = K[i_i * BI + b_i, d_i]
                # (h, d) @ (d, BI) -> (h, BI) scores for this query position.
                T.gemm(
                    Q_shared,
                    K_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                    clear_accum=True,
                )
                # Sum over heads, then write this BI-wide slice of the output row.
                T.reduce_sum(acc_s, acc_s_sum, dim=0)
                T.copy(acc_s_sum, S[s_i, i_i * BI:(i_i + 1) * BI])

    return main


if __name__ == "__main__":
    t = 1024 * 8
    h = 16
    d = 128
    device = "cuda"
    print(f"Testing kernel with t={t}, h={h}, d={d}")

    query = torch.randn(t, h, d, dtype=torch.bfloat16, device=device)
    key = torch.randn(t, d, dtype=torch.bfloat16, device=device)
    kernel = kernel_impl(t, h, d, threads=64)
    logits = kernel(query, key)

    # reference
    rf_logits_tht = torch.einsum('thd,kd->thk', query, key)
    rf_logits = rf_logits_tht.sum(dim=1)
    print(f"logits: {logits[0][-5:]}")
    print(f"reference logits: {rf_logits[0][-5:]}")

    # check diff
    diff = torch.abs(logits - rf_logits)
    print(f"max diff: {diff.max().item()}")
    print(f"mean diff: {diff.mean().item()}")

    from tilelang.profiler import do_bench

    def run_optimized():
        return kernel(query, key)

    optimized_ms = do_bench(run_optimized, rep=100, warmup=10)
    print(f"Tile-lang Kernel time: {optimized_ms:.3f} ms")

    def run_reference():
        return torch.einsum('thd,kd->thk', query, key).sum(dim=1)

    reference_ms = do_bench(run_reference, rep=100, warmup=10)
    print(f"Reference time: {reference_ms:.3f} ms")
    print(f"Speedup: {reference_ms / optimized_ms:.2f}x")
This one is better, but it still falls short of the PyTorch reference torch.einsum('thd,kd->thk', query, key).sum(dim=1):
max diff: 0.881103515625
mean diff: 0.08123190701007843
Tile-lang Kernel time: 2.876 ms
Reference time: 1.627 ms
Speedup: 0.57x
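One more data point worth measuring (my sketch, not an original result): since the sum over h commutes with the contraction over d, the reference itself collapses to a single (t, d) @ (d, t) matmul that never materializes the (t, h, t) intermediate, so that is arguably the real baseline a hand-written kernel has to beat:

def run_reduced():
    # Same math as the reference, expressed as one GEMM on the head-summed query.
    return query.sum(dim=1) @ key.T

reduced_ms = do_bench(run_reduced, rep=100, warmup=10)
print(f"Reduced matmul time: {reduced_ms:.3f} ms")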