
[Question] How to improve the efficiency of a TileLang kernel computing torch.einsum('thd,kd->thk', query, key).sum(dim=1)

Open CSQianDong opened this issue 1 month ago • 1 comment

Questions

Hi TileLang team,

I'm a new user of TileLang and trying to implement a kernel that computes: torch.einsum('thd,kd->thk', query, key).sum(dim=1)
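
As a side note on the math itself: the sum over h distributes over the dot product with key, so the result equals a single matmul on the head-summed query. A small fp32 sanity check of that identity (toy shapes, separate from the kernel below):

import torch

t, h, d, k = 32, 16, 128, 64
query = torch.randn(t, h, d)
key = torch.randn(k, d)

ref = torch.einsum('thd,kd->thk', query, key).sum(dim=1)
# sum_h (Q[t, h, :] . K[k, :]) == (sum_h Q[t, h, :]) . K[k, :]
collapsed = query.sum(dim=1) @ key.T
print((ref - collapsed).abs().max())  # only fp32 rounding noise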

Here’s my current kernel code:

import torch
import tilelang
import tilelang.language as T

@tilelang.jit(
    out_idx=[-1],
)
def kernel_impl(
    t: int,
    h: int,
    d: int,
    block_size: int,
    BLOCK_D: int,
    dtype: str = "bfloat16",
    accum_dtype: str = "float",
):
    
    TILE_D = T.ceildiv(BLOCK_D, block_size)

    @T.prim_func
    def main(
            Q: T.Tensor((t, h, d), dtype),
            K: T.Tensor((t, d), dtype),
            S: T.Tensor((t, t), dtype),
    ):
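        # One thread block per query row bx; the block_size threads split the d dimension between them.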
        with T.Kernel(t, threads=block_size) as bx:
            tk = T.get_thread_binding(0)
            K_local = T.alloc_local((TILE_D,), dtype)
            Q_local = T.alloc_local((TILE_D,), dtype)
            S_accum = T.alloc_local((1,), accum_dtype)
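            # Loop over all key rows; each thread accumulates a partial dot product over its TILE_D slice of d.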
            for by in T.Parallel(t):
                T.clear(S_accum)
                for bh in T.serial(h):
                    for bk in T.serial(T.ceildiv(d, BLOCK_D)):
                        for k in T.vectorized(TILE_D):
                            d_idx = bk * BLOCK_D + tk * TILE_D + k
                            K_local[k] = K[by, d_idx] 
                            Q_local[k] = Q[bx, bh, d_idx]
                        for k in T.serial(TILE_D):
                            S_accum[0] += K_local[k].astype(accum_dtype) * Q_local[k].astype(accum_dtype)
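                # Block-wide allreduce combines the per-thread partial sums into S_reduced[0].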
                S_reduced = T.alloc_local((1,), accum_dtype)
                with T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                ):
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            S_accum[0],
                            True,
                            S_reduced[0],
                            tk,
                            dtype="handle",
                        ))
                S[bx, by] = S_reduced[0]

    return main


if __name__ == "__main__":
    t = 1024*1
    h = 16       
    d = 128
    device = "cuda"
    
    print(f"Testing kernel with t={t}, h={h}, d={d}")
    query = torch.randn(t, h, d, dtype=torch.bfloat16, device=device)
    key = torch.randn(t, d, dtype=torch.bfloat16, device=device)
    
    kernel = kernel_impl(t, h, d, block_size=16, BLOCK_D=128)
    logits = kernel(query, key)
    
    # reference
    rf_logits_tht = torch.einsum('thd,kd->thk', query, key)
    rf_logits = rf_logits_tht.sum(dim=1)
    
    print(f"logits: {logits[0][-5:]}")
    print(f"reference logits: {rf_logits[0][-5:]}")
    
    # check diff
    diff = torch.abs(logits - rf_logits) 
    
    print(f"max diff: {diff.max().item()}")
    print(f"mean diff: {diff.mean().item()}")
    from tilelang.profiler import do_bench
    def run_optimized():
        return kernel(query, key)
    optimized_ms = do_bench(run_optimized, rep=100, warmup=10)
    print(f"Tile-lang Kernel time: {optimized_ms:.3f} ms")
    def run_reference():
        return torch.einsum('thd,kd->thk', query, key).sum(dim=1)
    reference_ms = do_bench(run_reference, rep=100, warmup=10)
    print(f"Reference time: {reference_ms:.3f} ms")
    print(f"Speedup: {reference_ms / optimized_ms:.2f}x")

Benchmark results (bfloat16 input, float accumulation):

max diff: 1.0
mean diff: 0.05908203125
Tile-lang Kernel time: 0.705 ms
Reference time: 0.034 ms
Speedup: 0.05x

It seems my kernel has:

  • Large numerical deviation from the PyTorch reference
  • Much slower runtime (0.705 ms vs 0.034 ms)

How can I optimize this so that the TileLang kernel beats PyTorch? Thanks a lot for your guidance!
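
On the first point above, a quick way to separate a real indexing bug from bfloat16 rounding is to compare both outputs against an fp32 reference. A sketch that could be appended to the script above:

# Sketch: fp32 reference to separate real bugs from bf16 accumulation-order noise.
rf32 = torch.einsum('thd,kd->thk', query.float(), key.float()).sum(dim=1)
print(f"kernel vs fp32 ref:   {(logits.float() - rf32).abs().max().item()}")
print(f"bf16 ref vs fp32 ref: {(rf_logits.float() - rf32).abs().max().item()}")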

CSQianDong · Nov 21 '25 16:11

import torch
import tilelang
import tilelang.language as T

@tilelang.jit(
    out_idx=[-1],
)
def kernel_impl(
    t: int,
    h: int,
    d: int,
    threads: int,
    dtype: str = "bfloat16",
    accum_dtype: str = "float",
):
    num_stages = 3
    BI = 64
    NI = tilelang.cdiv(t, BI)
    @T.prim_func
    def main(
            Q: T.Tensor((t, h, d), dtype),
            K: T.Tensor((t, d), dtype),
            S: T.Tensor((t, t), accum_dtype),
    ):
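        # One thread block per query row; keys are streamed in tiles of BI through a pipelined [h, d] x [d, BI] GEMM.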
        with T.Kernel(t, threads=threads) as q_idx:
            s_i = q_idx
            Q_shared = T.alloc_shared([h, d], dtype)
            K_shared = T.alloc_shared([BI, d], dtype)
            acc_s = T.alloc_fragment([h, BI], accum_dtype)
            acc_s_sum = T.alloc_fragment([BI], accum_dtype)

            T.copy(Q[s_i, :, :], Q_shared)
            for i_i in T.Pipelined(NI, num_stages=num_stages):
                for b_i, d_i in T.Parallel(BI, d):
                    K_shared[b_i, d_i] = K[i_i * BI + b_i, d_i]
                T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                        clear_accum=True,
                )
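                # Sum the [h, BI] per-head scores over heads (dim 0) to get the BI outputs for this key tile.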
                T.reduce_sum(acc_s, acc_s_sum, dim=0)
                T.copy(acc_s_sum, S[s_i, i_i * BI:(i_i + 1) * BI])
    return main


if __name__ == "__main__":
    t = 1024*8
    h = 16       
    d = 128
    device = "cuda"
    
    print(f"Testing kernel with t={t}, h={h}, d={d}")
    query = torch.randn(t, h, d, dtype=torch.bfloat16, device=device)
    key = torch.randn(t, d, dtype=torch.bfloat16, device=device)
    
    kernel = kernel_impl(t, h, d, threads=64)
    logits = kernel(query, key)
    
    # reference
    rf_logits_tht = torch.einsum('thd,kd->thk', query, key)
    rf_logits = rf_logits_tht.sum(dim=1)
    
    print(f"logits: {logits[0][-5:]}")
    print(f"reference logits: {rf_logits[0][-5:]}")
    
    # check diff
    diff = torch.abs(logits - rf_logits) 
    
    print(f"max diff: {diff.max().item()}")
    print(f"mean diff: {diff.mean().item()}")
    from tilelang.profiler import do_bench
    def run_optimized():
        return kernel(query, key)
    optimized_ms = do_bench(run_optimized, rep=100, warmup=10)
    print(f"Tile-lang Kernel time: {optimized_ms:.3f} ms")
    def run_reference():
        return torch.einsum('thd,kd->thk', query, key).sum(dim=1)
    reference_ms = do_bench(run_reference, rep=100, warmup=10)
    print(f"Reference time: {reference_ms:.3f} ms")
    print(f"Speedup: {reference_ms / optimized_ms:.2f}x")

This one is better, but it still lags behind torch.einsum('thd,kd->thk', query, key).sum(dim=1):

max diff: 0.881103515625
mean diff: 0.08123190701007843
Tile-lang Kernel time: 2.876 ms
Reference time: 1.627 ms
Speedup: 0.57x
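
One possible further direction (a sketch only, not benchmarked): because the sum over h distributes over the dot product, S equals query.sum(dim=1) @ key.T, so the whole computation collapses to a single (t, d) x (d, t) GEMM on a pre-summed query. A tiled TileLang matmul in the style of the standard example, with block_M/block_N/block_K chosen arbitrarily here, could look like this:

import torch
import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def qsum_matmul(
    t: int,
    d: int,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 64,
    dtype: str = "bfloat16",
    accum_dtype: str = "float",
):
    @T.prim_func
    def main(
            Qs: T.Tensor((t, d), dtype),      # query summed over heads
            K: T.Tensor((t, d), dtype),
            S: T.Tensor((t, t), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(t, block_N), T.ceildiv(t, block_M), threads=128) as (bx, by):
            Qs_shared = T.alloc_shared((block_M, block_K), dtype)
            K_shared = T.alloc_shared((block_N, block_K), dtype)
            S_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(S_local)
            for ko in T.Pipelined(T.ceildiv(d, block_K), num_stages=2):
                T.copy(Qs[by * block_M, ko * block_K], Qs_shared)
                T.copy(K[bx * block_N, ko * block_K], K_shared)
                # S_tile += Qs_tile @ K_tile^T
                T.gemm(Qs_shared, K_shared, S_local, transpose_B=True)
            T.copy(S_local, S[by * block_M, bx * block_N])
    return main

# Usage sketch: do the head sum on the host (in fp32 for accuracy), then run one GEMM.
# q_sum = query.float().sum(dim=1).to(torch.bfloat16)
# logits = qsum_matmul(t, d)(q_sum, key)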

CSQianDong · Nov 21 '25 17:11