tilelang

What's the C++ compile and calling API of tilelang kernel?

Open MoFHeka opened this issue 6 months ago • 5 comments

In XLA, the PJRT API can be used to access XLA kernels from C++ code; this is how the PyTorch XLA backend is implemented. Is there any way to use the TileLang core API to load, compile and run "python code" from C++?

MoFHeka avatar Jun 24 '25 12:06 MoFHeka

@MoFHeka Thank you for your attention. You can obtain the compiled CUDA kernel code using the `get_kernel_source` interface. Additionally, the compiled library is available via `kernel.adapter.lib` (and its path via `kernel.adapter.libpath`), as shown below. Hope this helps.

# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function
# designed specifically for MMA operations. It keeps the layout
# consistent with NVIDIA's CUTLASS library in order to avoid
# bank conflicts and maximize performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,)  # noqa: F401


# Uncomment the @tilelang.jit decorator below if you want matmul to return a torch function.
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a tiled TileLang GEMM program computing C = A @ B.

    A is (M, K), B is (K, N) and C is (M, N), all with element type
    ``dtype``. Each thread block (128 threads) produces one
    ``block_M x block_N`` tile of C: it walks the K dimension in
    ``block_K`` steps with a 3-stage software pipeline, accumulates the
    partial products in a fragment of type ``accum_dtype``, and finally
    copies that fragment into C.

    Args:
        M, N, K: Problem dimensions.
        block_M, block_N, block_K: Tile sizes handled per thread block.
        dtype: Element type of A, B and C.
        accum_dtype: Element type of the per-tile accumulator.

    Returns:
        The ``T.prim_func`` named ``main``; compile it with ``tilelang.compile``.

    NOTE(review): handling of M/N/K values that are not multiples of the
    block sizes is not visible here — confirm ``T.copy`` bounds the tails.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize kernel context: grid of ceil(N/block_N) x ceil(M/block_M)
        # blocks, 128 threads each; bx indexes N-tiles and by indexes M-tiles.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
            # T.annotate_layout({
            #     A_shared: make_swizzle_layout(A_shared),
            #     B_shared: make_swizzle_layout(B_shared),
            # })

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is syntactic sugar for a parallelized copy:
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently dispatched to cute on NVIDIA GPUs and to hip on AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# Problem shape; use M = T.symbolic("m") instead if you want a dynamic shape.
M = N = K = 1024
# Tile sizes handled by one thread block.
block_M, block_N, block_K = 128, 128, 32

# Build the TileLang program and JIT-compile it for CUDA.
# out_idx=[2] selects parameter index 2 of `main` (i.e. C) — presumably
# TileLang then allocates that tensor as the output; confirm in tilelang docs.
func = matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K)
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# Generated CUDA source of the kernel.
print(jit_kernel.get_kernel_source())

# Path of the compiled shared library, for example:
#   ~/.tilelang/cache/<hash>/kernel_lib.so
print(jit_kernel.adapter.libpath)

# The loaded library as a ctypes.CDLL handle, for example:
#   <CDLL '~/.tilelang/cache/<hash>/kernel_lib.so', handle ... at ...>
print(jit_kernel.adapter.lib)

LeiWang1999 avatar Jun 24 '25 12:06 LeiWang1999

@LeiWang1999 Thank you for replying. Where can I find the header to include? I checked tilelang/jit/adapter/libgen.py, and there is only the library path. Is it COMPOSABLE_KERNEL_INCLUDE_DIR?

MoFHeka avatar Jun 24 '25 13:06 MoFHeka

@MoFHeka Unfortunately no, we haven't generated the corresponding header files for the kernel source because TileLang's JIT doesn't actually need these files. However, I think we could easily implement this feature. It would require modifying this file: https://github.com/tile-ai/tilelang/blob/main/tilelang/jit/adapter/wrapper.py

I'm not entirely sure what changes would be needed to meet your requirements. If you're interested in working on this, we'd be happy to assist you with the implementation.

LeiWang1999 avatar Jun 24 '25 14:06 LeiWang1999

@LeiWang1999 A header is not necessary. Any method that lets me run the corresponding __global__ function is fine. Can I find the symbol of the generated function in wrapper.py?

MoFHeka avatar Jun 25 '25 01:06 MoFHeka

@LeiWang1999 Yes, for sure. Actually, if you take a look at kernel.get_kernel_source(), you can find some functions:

#include <math_constants.h>
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

// Compiler-generated CUDA kernel, as printed by kernel.get_kernel_source().
// NOTE(review): this is generated code — change the TileLang program rather
// than editing this text by hand.
// Reads Delta, Indices, KV, Lse, Q, Q_d and dO; the only global stores write dQ.
// Q, KV and dQ rows use a 576-element stride: columns [0, 512) and the tail
// columns [512, 576) are handled by separate copies/GEMMs below.
extern "C" __global__ void main_kernel(float* __restrict__ Delta, int* __restrict__ Indices, bfloat16_t* __restrict__ KV, float* __restrict__ Lse, bfloat16_t* __restrict__ Q, bfloat16_t* __restrict__ Q_d, bfloat16_t* __restrict__ dO, bfloat16_t* __restrict__ dQ);
// Definition; __launch_bounds__(128, 1) matches the 128-thread launch in call().
extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(float* __restrict__ Delta, int* __restrict__ Indices, bfloat16_t* __restrict__ KV, float* __restrict__ Lse, bfloat16_t* __restrict__ Q, bfloat16_t* __restrict__ Q_d, bfloat16_t* __restrict__ dO, bfloat16_t* __restrict__ dQ) {
  // Dynamic shared memory, addressed below as bf16 elements:
  //   [0, 2048)       KV tail tile (columns 512..575 of 32 gathered rows)
  //   [2048, 4096)    bf16 copy of acc_p
  //   [4096, 6144)    bf16 acc_p * (acc_dp - delta) * scale tile
  //   [6144, 10240)   Q tail tile (columns 512..575 of 64 rows)
  //   [10240, 26624)  KV main tile (columns 0..511 of 32 gathered rows)
  //   [26624, 59392)  Q main tile (columns 0..511 of 64 rows)
  //   [59392, 92160)  dO tile (92160 bf16 = 184320 bytes, the size set in init())
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float lse[2];
  float delta[2];
  float acc_dq[256];
  float acc_dq_tail[32];
  signed char mask[8];
  float acc_p[16];
  float acc_dp[16];
  // Stage this block's 64 Q rows, columns [0, 512), into shared offset 26624.
  #pragma unroll
  for (int i = 0; i < 32; ++i) {
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((((((int)threadIdx.x) & 63) >> 3) * 4096) + (i * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + ((i & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (i & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 26624)) = *(uint4*)(Q + ((((((int)blockIdx.x) * 36864) + (i * 1152)) + ((((int)threadIdx.x) >> 6) * 576)) + ((((int)threadIdx.x) & 63) * 8)));
  }
  // Stage the same 64 Q rows, tail columns [512, 576), into shared offset 6144.
  #pragma unroll
  for (int i_1 = 0; i_1 < 4; ++i_1) {
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + ((((((i_1 * 1024) + ((((int)threadIdx.x) >> 3) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 6144)) = *(uint4*)(Q + (((((((int)blockIdx.x) * 36864) + (i_1 * 9216)) + ((((int)threadIdx.x) >> 3) * 576)) + ((((int)threadIdx.x) & 7) * 8)) + 512));
  }
  // Stage 8192 contiguous Q_d elements into shared offset 59392.
  // NOTE(review): the dO staging loop below writes [59392, 92160), which
  // covers this whole range before any read, and no later GEMM reads Q_d.
  // This looks like a dead load or a shared-memory aliasing bug — confirm.
  #pragma unroll
  for (int i_2 = 0; i_2 < 8; ++i_2) {
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((i_2 * 1024) + (((int)threadIdx.x) * 8)) + 59392)) = *(uint4*)(Q_d + (((((int)blockIdx.x) * 8192) + (i_2 * 1024)) + (((int)threadIdx.x) * 8)));
  }
  // Stage this block's 64 dO rows (512 columns each) into shared offset 59392.
  #pragma unroll
  for (int i_3 = 0; i_3 < 32; ++i_3) {
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((((((int)threadIdx.x) & 63) >> 3) * 4096) + (i_3 * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + ((i_3 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (i_3 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 59392)) = *(uint4*)(dO + (((((int)blockIdx.x) * 32768) + (i_3 * 1024)) + (((int)threadIdx.x) * 8)));
  }
  // Load the two per-row Lse values this thread's accumulator rows need.
  #pragma unroll
  for (int i_4 = 0; i_4 < 2; ++i_4) {
    lse[i_4] = Lse[((((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + (i_4 * 8)) + ((((int)threadIdx.x) & 31) >> 2))];
  }
  // Load the matching two per-row Delta values.
  #pragma unroll
  for (int i_5 = 0; i_5 < 2; ++i_5) {
    delta[i_5] = Delta[((((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + (i_5 * 8)) + ((((int)threadIdx.x) & 31) >> 2))];
  }
  // Zero the dQ accumulators (main columns and tail columns).
  #pragma unroll
  for (int i_6 = 0; i_6 < 128; ++i_6) {
    *(float2*)(acc_dq + (i_6 * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  #pragma unroll
  for (int i_7 = 0; i_7 < 16; ++i_7) {
    *(float2*)(acc_dq_tail + (i_7 * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  // Main loop: 16 chunks of 32 Indices entries (512 entries per block).
  for (int i_i = 0; i_i < 16; ++i_i) {
    // mask[j] = 1 when the index is <= (blockIdx.x + 4093) >> 2, else 0
    // (presumably a causal limit — confirm against the TileLang source).
    #pragma unroll
    for (int i_8 = 0; i_8 < 4; ++i_8) {
      char2 __1;
      ushort2 __2;
        int2 v_ = *(int2*)(Indices + ((((((int)blockIdx.x) * 512) + (i_i * 32)) + (i_8 * 8)) + ((((int)threadIdx.x) & 3) * 2)));
        int2 v__1 = make_int2(((((int)blockIdx.x) + 4093) >> 2), ((((int)blockIdx.x) + 4093) >> 2));
        __2.x = (v_.x<=v__1.x);
        __2.y = (v_.y<=v__1.y);
      __1.x=((signed char)(__2.x));
      __1.y=((signed char)(__2.y));
      *(char2*)(mask + (i_8 * 2)) = __1;
    }
    __syncthreads();
    // Gather the 32 KV rows named by this chunk's Indices (columns [0, 512))
    // into shared offset 10240; indices outside [0, 2048) load zeros.
    #pragma unroll
    for (int i_9 = 0; i_9 < 16; ++i_9) {
      uint4 condval;
      if (((0 <= Indices[((((((int)blockIdx.x) * 512) + (i_i * 32)) + (i_9 * 2)) + (((int)threadIdx.x) >> 6))]) && (Indices[((((((int)blockIdx.x) * 512) + (i_i * 32)) + (i_9 * 2)) + (((int)threadIdx.x) >> 6))] < 2048))) {
        condval = *(uint4*)(KV + ((((int64_t)Indices[((((((int64_t)((int)blockIdx.x)) * (int64_t)512) + (((int64_t)i_i) * (int64_t)32)) + (((int64_t)i_9) * (int64_t)2)) + (((int64_t)((int)threadIdx.x)) >> (int64_t)6))]) * (int64_t)576) + ((((int64_t)((int)threadIdx.x)) & (int64_t)63) * (int64_t)8)));
      } else {
        condval = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
      }
      *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((((((int)threadIdx.x) & 63) >> 3) * 2048) + (i_9 * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + ((i_9 & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (i_9 & 1)) & 1) * 16)) + ((((((int)threadIdx.x) >> 6) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 10240)) = condval;
    }
    // Gather the tail columns [512, 576) of the same 32 KV rows into shared
    // offset 0, with the same [0, 2048) bounds check.
    #pragma unroll
    for (int i_10 = 0; i_10 < 2; ++i_10) {
      uint4 condval_1;
      if (((0 <= Indices[((((((int)blockIdx.x) * 512) + (i_i * 32)) + (i_10 * 16)) + (((int)threadIdx.x) >> 3))]) && (Indices[((((((int)blockIdx.x) * 512) + (i_i * 32)) + (i_10 * 16)) + (((int)threadIdx.x) >> 3))] < 2048))) {
        condval_1 = *(uint4*)(KV + (((((int64_t)Indices[((((((int64_t)((int)blockIdx.x)) * (int64_t)512) + (((int64_t)i_i) * (int64_t)32)) + (((int64_t)i_10) * (int64_t)16)) + (((int64_t)((int)threadIdx.x)) >> (int64_t)3))]) * (int64_t)576) + ((((int64_t)((int)threadIdx.x)) & (int64_t)7) * (int64_t)8)) + (int64_t)512));
      } else {
        condval_1 = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
      }
      *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((i_10 * 1024) + ((((int)threadIdx.x) >> 3) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8))) = condval_1;
    }
    // Seed acc_p with 0 where mask is set and -inf elsewhere; masked entries
    // stay -inf through the GEMMs and become 0 after exp2f below.
    #pragma unroll
    for (int i_11 = 0; i_11 < 16; ++i_11) {
      float condval_2;
      if (((bool)mask[(((i_11 >> 2) * 2) + (i_11 & 1))])) {
        condval_2 = 0.000000e+00f;
      } else {
        condval_2 = -CUDART_INF_F;
      }
      acc_p[i_11] = condval_2;
    }
    tl::fence_proxy_async();
    __syncthreads();
    // acc_p += Q main tile x KV main tile, then += Q tail tile x KV tail tile.
    tl::gemm_ss<64, 32, 512, 4, 1, 0, 1, 0, true>((&(((bfloat16_t*)buf_dyn_shmem)[26624])), (&(((bfloat16_t*)buf_dyn_shmem)[10240])), (&(acc_p[0])));
    tl::gemm_ss<64, 32, 64, 4, 1, 0, 1, 0, true>((&(((bfloat16_t*)buf_dyn_shmem)[6144])), (&(((bfloat16_t*)buf_dyn_shmem)[0])), (&(acc_p[0])));
    // acc_p = exp2(acc_p * 6.011229e-02 - lse). 6.011229e-02 equals log2(e)/24;
    // presumably a 1/sqrt(576) softmax scale folded with log2(e) — confirm.
    #pragma unroll
    for (int i_12 = 0; i_12 < 16; ++i_12) {
      acc_p[i_12] = exp2f(((acc_p[i_12] * 6.011229e-02f) - lse[((i_12 & 3) >> 1)]));
    }
    __syncthreads();
    // Store acc_p as bf16 at shared offset 2048 via stmatrix.
    // NOTE(review): no later read of [2048, 4096) is visible in this kernel.
    #pragma unroll
    for (int i_13 = 0; i_13 < 2; ++i_13) {
      tl::ptx_stmatrix_x4((&(((bfloat16_t*)buf_dyn_shmem)[((((((((int)threadIdx.x) >> 5) * 512) + ((((int)threadIdx.x) & 15) * 32)) + (i_13 * 16)) + (((((int)threadIdx.x) & 31) >> 4) * 8)) + 2048)])), __pack_half2(((bfloat16_t)acc_p[(i_13 * 8)]), ((bfloat16_t)acc_p[((i_13 * 8) + 1)])), __pack_half2(((bfloat16_t)acc_p[((i_13 * 8) + 2)]), ((bfloat16_t)acc_p[((i_13 * 8) + 3)])), __pack_half2(((bfloat16_t)acc_p[((i_13 * 8) + 4)]), ((bfloat16_t)acc_p[((i_13 * 8) + 5)])), __pack_half2(((bfloat16_t)acc_p[((i_13 * 8) + 6)]), ((bfloat16_t)acc_p[((i_13 * 8) + 7)])));
    }
    // Zero acc_dp before accumulating into it.
    #pragma unroll
    for (int i_14 = 0; i_14 < 8; ++i_14) {
      *(float2*)(acc_dp + (i_14 * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
    }
    tl::fence_proxy_async();
    __syncthreads();
    // acc_dp += dO tile x KV main tile.
    tl::gemm_ss<64, 32, 512, 4, 1, 0, 1, 0, true>((&(((bfloat16_t*)buf_dyn_shmem)[59392])), (&(((bfloat16_t*)buf_dyn_shmem)[10240])), (&(acc_dp[0])));
    __syncthreads();
    // Elementwise acc_p * (acc_dp - delta) * 4.166667e-02 (= 1/24), converted
    // to bf16 and stored at shared offset 4096 for the dQ GEMMs.
    #pragma unroll
    for (int i_15 = 0; i_15 < 8; ++i_15) {
      uint1 __3;
      float2 __4;
        float2 __5;
          float2 v__2 = *(float2*)(acc_p + (i_15 * 2));
          float2 __6;
            float2 v__3 = *(float2*)(acc_dp + (i_15 * 2));
            float2 v__4 = make_float2(delta[(i_15 & 1)], delta[(i_15 & 1)]);
            __6.x = (v__3.x-v__4.x);
            __6.y = (v__3.y-v__4.y);
          __5.x = (v__2.x*__6.x);
          __5.y = (v__2.y*__6.y);
        float2 v__5 = make_float2(4.166667e-02f, 4.166667e-02f);
        __4.x = (__5.x*v__5.x);
        __4.y = (__5.y*v__5.y);
      ((nv_bfloat162*)(&(__3.x)))->x = (bfloat16_t)(__4.x);
      ((nv_bfloat162*)(&(__3.x)))->y = (bfloat16_t)(__4.y);
      *(uint1*)(((bfloat16_t*)buf_dyn_shmem) + ((((((((((int)threadIdx.x) >> 5) * 512) + ((i_15 & 1) * 256)) + (((((int)threadIdx.x) & 31) >> 2) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (i_15 >> 2)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + ((i_15 & 3) >> 1)) & 1) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + 4096)) = __3;
    }
    tl::fence_proxy_async();
    __syncthreads();
    // acc_dq += tile@4096 x KV main tile; acc_dq_tail += tile@4096 x KV tail tile.
    tl::gemm_ss<64, 512, 32, 4, 1, 0, 0, 0, true>((&(((bfloat16_t*)buf_dyn_shmem)[4096])), (&(((bfloat16_t*)buf_dyn_shmem)[10240])), (&(acc_dq[0])));
    tl::gemm_ss<64, 64, 32, 4, 1, 0, 0, 0, true>((&(((bfloat16_t*)buf_dyn_shmem)[4096])), (&(((bfloat16_t*)buf_dyn_shmem)[0])), (&(acc_dq_tail[0])));
  }
  // Write acc_dq to dQ columns [0, 512) of this block's 64 rows as bf16.
  #pragma unroll
  for (int i_16 = 0; i_16 < 128; ++i_16) {
    uint1 __7;
    float2 v__6 = *(float2*)(acc_dq + (i_16 * 2));
    ((nv_bfloat162*)(&(__7.x)))->x = (bfloat16_t)(v__6.x);
    ((nv_bfloat162*)(&(__7.x)))->y = (bfloat16_t)(v__6.y);
    *(uint1*)(dQ + ((((((((int)blockIdx.x) * 36864) + ((((int)threadIdx.x) >> 5) * 9216)) + ((i_16 & 1) * 4608)) + (((((int)threadIdx.x) & 31) >> 2) * 576)) + ((i_16 >> 1) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __7;
  }
  // Write acc_dq_tail to dQ tail columns [512, 576) as bf16.
  #pragma unroll
  for (int i_17 = 0; i_17 < 16; ++i_17) {
    uint1 __8;
    float2 v__7 = *(float2*)(acc_dq_tail + (i_17 * 2));
    ((nv_bfloat162*)(&(__8.x)))->x = (bfloat16_t)(v__7.x);
    ((nv_bfloat162*)(&(__8.x)))->y = (bfloat16_t)(v__7.y);
    *(uint1*)(dQ + (((((((((int)blockIdx.x) * 36864) + ((((int)threadIdx.x) >> 5) * 9216)) + ((i_17 & 1) * 4608)) + (((((int)threadIdx.x) & 31) >> 2) * 576)) + ((i_17 >> 1) * 8)) + ((((int)threadIdx.x) & 3) * 2)) + 512)) = __8;
  }
}


// Capacity, in bytes, of the buffer holding the most recent error message.
#define ERROR_BUF_SIZE 1024
// Most recent error message; init() clears it and fills it on failure.
static char error_buf[ERROR_BUF_SIZE];

// Expose the last recorded error message to the host-side caller.
// The returned pointer always refers to error_buf and is never null.
extern "C" const char* get_last_error() {
    const char* message = &error_buf[0];
    return message;
}

// One-time setup: raise main_kernel's dynamic shared-memory limit to the
// 184320 bytes that call() passes at launch.
// Returns 0 on success. On failure, writes a message into error_buf (readable
// through get_last_error()) and returns -1.
extern "C" int init() {
    error_buf[0] = '\0';

    // Dynamic shared memory used by main_kernel; must match the launch in call().
    constexpr int kDynamicSmemBytes = 184320;
    cudaError_t result_main_kernel = cudaFuncSetAttribute(main_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynamicSmemBytes);
    // cudaFuncSetAttribute is a runtime-API call returning cudaError_t, so it
    // must be compared with cudaSuccess, not the driver-API CUresult CUDA_SUCCESS.
    if (result_main_kernel != cudaSuccess) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", kDynamicSmemBytes, cudaGetErrorString(result_main_kernel));
        return -1;
    }

    return 0;
}

// Host-side entry point: launches main_kernel on `stream` with 4096 blocks of
// 128 threads and 184320 bytes of dynamic shared memory (the limit set by init()).
// Only Delta, Indices, KV, Lse, Q, Q_d, dO and dQ are forwarded; K_d, Lse_d,
// dKV, dQ_d, dK_d and aux_loss_weight are accepted but unused by this launch.
// TILELANG_CHECK_LAST_ERROR is defined elsewhere — presumably it reports a
// failed launch; confirm its behavior in the TileLang templates.
// Returns 0 when execution reaches the end.
extern "C" int call(bfloat16_t* __restrict__ Q, bfloat16_t* __restrict__ KV, bfloat16_t* __restrict__ Q_d, bfloat16_t* __restrict__ K_d, bfloat16_t* __restrict__ dO, int* __restrict__ Indices, float* __restrict__ Lse, float* __restrict__ Lse_d, float* __restrict__ Delta, bfloat16_t* __restrict__ dQ, float* __restrict__ dKV, bfloat16_t* __restrict__ dQ_d, float* __restrict__ dK_d, float aux_loss_weight, cudaStream_t stream=cudaStreamDefault) {
    const dim3 launch_grid(4096, 1, 1);
    const dim3 launch_block(128, 1, 1);
    const size_t launch_smem_bytes = 184320;

    main_kernel<<<launch_grid, launch_block, launch_smem_bytes, stream>>>(Delta, Indices, KV, Lse, Q, Q_d, dO, dQ);
    TILELANG_CHECK_LAST_ERROR("main_kernel");

    return 0;
}


LeiWang1999 avatar Jun 26 '25 05:06 LeiWang1999