
Int8 Convolution

jonpryai opened this issue 6 months ago • 7 comments

I tried to modify the convolution example to use int8 inputs and an int32 accumulator. This seems like it should work, because the underlying GEMM supports these data types. However, it fails, presumably because of the swizzle, with:

InternalError: Check failed: (StructuralEqual()(layout, layout_map[buffer])) is false: Get different layout for kernel_shared

I'm a little confused as to how the swizzle is supposed to work here. My understanding is that narrower data types sometimes need to be packed into 4-byte types in shared memory to avoid bank conflicts, but I haven't figured out how that could even be done in TileLang.

Any advice?

jonpryai avatar May 27 '25 21:05 jonpryai
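For context, the "underlying GEMM" referred to above is TileLang's `T.gemm` tile-level GEMM. A minimal int8 → int32 GEMM in the same style as the convolution example posted below would look roughly like the following. This is an illustrative, untested sketch (not the reporter's code); the block sizes are arbitrary, and, per the last comment in this thread, the plain non-transposed `T.gemm` call may itself need `transpose_B=True` (with B laid out as (N, K)) on Ampere.

```python
# Illustrative sketch only: a minimal int8 GEMM with an int32 accumulator,
# written with the same constructs as the convolution example below.
import tilelang
import tilelang.language as T


def int8_matmul(M, N, K, block_M=64, block_N=64, block_K=64):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), "int8"),
            B: T.Tensor((K, N), "int8"),
            C: T.Tensor((M, N), "int32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "int8")
            B_shared = T.alloc_shared((block_K, block_N), "int8")
            C_local = T.alloc_fragment((block_M, block_N), "int32")

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Per a later comment, int8 on Ampere may require
                # transpose_B=True with B_shared shaped (block_N, block_K).
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# Usage: compile and run like the convolution example below.
kernel = tilelang.compile(int8_matmul(1024, 1024, 1024), out_idx=[2])
```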

Hi @jonpryai, thanks for the report. Would you mind providing the complete code to help us reproduce?

LeiWang1999 avatar May 28 '25 04:05 LeiWang1999

```python
# Licensed under the MIT License.
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from tilelang.carver.template import ConvTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization


def check_hopper():
    # if not torch.cuda.is_available():
    #     return None
    # props = torch.cuda.get_device_properties(0)
    # compute_capability = props.major, props.minor
    # return compute_capability == (9, 0)
    return False


def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False):
    if with_roller:
        arch = CUDA("cuda")
        topk = 10
        carve_template = ConvTemplate(
            N=N,
            C=C,
            H=H,
            W=W,
            F=F,
            K=K,
            S=S,
            D=D,
            P=P,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            block_rows, block_cols = block_m // warp_m, block_n // warp_n
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = hint.pipeline_stage
            config["thread_num"] = block_rows * block_cols * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64, 128, 256]
        block_N = [64, 128, 256]
        block_K = [32, 64]
        num_stages = [0, 1, 2, 3]
        threads = [128, 256]
        _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads))

        configs = [{
            'block_M': c[0],
            'block_N': c[1],
            'block_K': c[2],
            'num_stages': c[3],
            'thread_num': c[4]
        } for c in _configs]
    return configs


def ref_program(stride, padding, dilation):

    def main(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
        C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
        C = C.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
        return C

    return main


def get_best_config(N, C, H, W, F, K, S, D, P, with_roller):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

    dtype = "int8"
    accum_dtype = "int32"
    is_hopper = check_hopper()

    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
    ):

        @T.prim_func
        def main(
                data: T.Tensor((N, H, W, C), dtype),
                kernel: T.Tensor((KH, KW, C, F), dtype),
                out: T.Tensor((N, OH, OW, F), dtype),
        ):
            with T.Kernel(
                    T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                    threads=thread_num) as (bx, by):
                data_shared = T.alloc_shared((block_M, block_K), dtype)
                kernel_shared = T.alloc_shared((block_K, block_N), dtype)
                out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                out_shared = T.alloc_shared((block_M, block_N), dtype)

                kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
                out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

                T.annotate_layout({
                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
                })

                T.clear(out_local)
                for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                    if is_hopper:
                        T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                    else:
                        for i, j in T.Parallel(block_M, block_K):
                            k = k_iter * block_K + j
                            m = by * block_M + i
                            access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                            access_w = m % OW * S + k // C % KW * D - P
                            in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                        (access_w < W))
                            data_shared[i, j] = T.if_then_else(
                                in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                    T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                    T.gemm(data_shared, kernel_shared, out_local)

                T.copy(out_local, out_shared)
                T.copy(out_shared, out_flat[by * block_M, bx * block_N])

        return main

    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
                                           with_roller)).set_compile_args(
                                               out_idx=[2],
                                               supply_type=tilelang.TensorSupplyType.Integer,
                                               ref_prog=ref_program(S, P, D),
                                               skip_check=False,
                                               target="auto",
                                           )
    return autotuner.run(warmup=10, rep=10)


def convolution(N,
                C,
                H,
                W,
                F,
                K,
                S,
                D,
                P,
                block_M,
                block_N,
                block_K,
                num_stages,
                threads,
                dtype="float16",
                accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    dtype = "float16"
    accum_dtype = "float"
    is_hopper = check_hopper()

    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                if is_hopper:
                    T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                else:
                    for i, j in T.Parallel(block_M, block_K):
                        k = k_iter * block_K + j
                        m = by * block_M + i
                        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                        access_w = m % OW * S + k // C % KW * D - P
                        in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                    (access_w < W))
                        data_shared[i, j] = T.if_then_else(
                            in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=128, help='n')
    parser.add_argument('--c', type=int, default=128, help='c')
    parser.add_argument('--h', type=int, default=64, help='h')
    parser.add_argument('--w', type=int, default=64, help='w')
    parser.add_argument('--f', type=int, default=128, help='f')
    parser.add_argument('--k', type=int, default=3, help='k')
    parser.add_argument('--s', type=int, default=1, help='s')
    parser.add_argument('--d', type=int, default=1, help='d')
    parser.add_argument('--p', type=int, default=1, help='p')
    parser.add_argument(
        "--use_autotune",
        action="store_true",
        default=True,
        help="Whether to use autotune for matmul configs")
    parser.add_argument(
        "--with_roller",
        action="store_true",
        default=True,
        help="Whether to enable BitBLAS roller for search space")

    args = parser.parse_args(argv)
    N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
    a = torch.randn(N, H, W, C).cuda().to(torch.int8)
    b = torch.randn(K, K, C, F).cuda().to(torch.int8)
    use_autotune = args.use_autotune
    with_roller = args.with_roller
    if use_autotune:
        result = get_best_config(N, C, H, W, F, K, S, D, P, with_roller)
        print(f"best latency {result.latency}")
        kernel = result.kernel
    else:
        kernel = tilelang.compile(
            convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2])

    out_c = kernel(a, b)
    ref_c = ref_program(S, P, D)(a, b)
    torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    main()
```

jonpry avatar May 28 '25 05:05 jonpry

@jonpry thanks, but this code works for me with autotune disabled on the latest nightly build.

LeiWang1999 avatar May 28 '25 07:05 LeiWang1999
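For anyone trying to reproduce that working configuration: in the posted script, bypassing the autotuner amounts to compiling the `convolution(...)` prim_func directly. A minimal sketch, assuming the `convolution` definition and the `a`, `b` tensors from the script above:

```python
# Non-autotuned path of the posted script (a sketch, not a verified config):
# argparse defaults N=128, C=128, H=64, W=64, F=128, K=3, S=1, D=1, P=1 and
# the hard-coded block_M=64, block_N=128, block_K=32, num_stages=3, threads=256.
kernel = tilelang.compile(
    convolution(128, 128, 64, 64, 128, 3, 1, 1, 1, 64, 128, 32, 3, 256),
    out_idx=[2])
out_c = kernel(a, b)  # a, b are the int8 tensors constructed in main()
```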

Likely part of this example (the autotune path especially) is a bit outdated; we will remove it to keep the example clean and add a separate autotune example for convolutions.

LeiWang1999 avatar May 28 '25 07:05 LeiWang1999

I just pulled the latest source and applied this patch:

```diff
index cb32d22..5ce5549 100644
--- a/examples/convolution/example_convolution.py
+++ b/examples/convolution/example_convolution.py
@@ -46,8 +46,8 @@ def convolution(N,
     KH, KW = K, K
     OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
     OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = "int8"
+    accum_dtype = "int32"
     is_hopper = check_hopper()
 
     @T.prim_func
@@ -110,8 +110,8 @@ def main(argv=None):
 
     args = parser.parse_args(argv)
     N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
-    a = torch.randn(N, H, W, C).cuda().half()
-    b = torch.randn(K, K, C, F).cuda().half()
+    a = torch.randn(N, H, W, C).cuda().to(torch.int8)
+    b = torch.randn(K, K, C, F).cuda().to(torch.int8)
 
     block_m = 64
     block_n = 128
```

With that I get the same `InternalError: Check failed: (StructuralEqual()(layout, layout_map[buffer])) is false: Get different layout for kernel_shared` as before.

jonpryai avatar May 28 '25 20:05 jonpryai

Could it be that you are using Hopper and I am running it on Ampere?

jonpryai avatar Jun 04 '25 20:06 jonpryai

I found that, at least on my Ampere machine, the int8 `T.gemm` must have B transposed to work, e.g. `T.gemm(data_shared, kernel_shared, out_local, transpose_B=True)` (see the sketch at the end of this thread).

This gets at least some configurations working. I also found that there are many cases where 16-lane vector types end up being used, which are not supported. Strangely, these kernels are functional if one adds a longlong8 type and removes the associated checks for lanes <= 8.

Still, the kernel performance is not very good: more than 2x the runtime of NVIDIA's int8 convolutions.

jonpryai avatar Sep 02 '25 22:09 jonpryai
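Below is a sketch of how the non-Hopper path of the example might look with the transpose_B workaround described above folded in: `kernel_shared` allocated as `(block_N, block_K)` and loaded transposed, everything else kept as in the posted example. This is untested and only illustrative; the explicit transposed load and keeping the swizzled-layout annotations are assumptions, not a confirmed fix.

```python
import tilelang
import tilelang.language as T


def int8_convolution(N, C, H, W, F, K, S, D, P,
                     block_M=64, block_N=128, block_K=32,
                     num_stages=3, threads=256):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    dtype = "int8"
    accum_dtype = "int32"

    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            # (block_N, block_K) instead of (block_K, block_N), to match transpose_B=True
            kernel_shared = T.alloc_shared((block_N, block_K), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                # im2col load of the data tile, as in the original example
                for i, j in T.Parallel(block_M, block_K):
                    k = k_iter * block_K + j
                    m = by * block_M + i
                    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                    access_w = m % OW * S + k // C % KW * D - P
                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                (access_w < W))
                    data_shared[i, j] = T.if_then_else(
                        in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                # transposed load of the weight tile: kernel_shared[n, k] = kernel_flat[k, n]
                for i, j in T.Parallel(block_N, block_K):
                    kernel_shared[i, j] = kernel_flat[k_iter * block_K + j, bx * block_N + i]
                T.gemm(data_shared, kernel_shared, out_local, transpose_B=True)

            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main
```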