
FP8 JIT compilation fails with TVM InternalError

Open · L1aoXingyu opened this issue 2 months ago · 3 comments

Minimal reproduction code:

import tilelang
import tilelang.language as T

FP8 = "float8_e4m3"
BF16 = "bfloat16"


@tilelang.jit
def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
    M = T.symbolic("M")
    blk_m = 128
    group_size = 128

    @T.prim_func
    def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (pid_m, pid_n):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
            T.copy(x_shared, Y[pid_m * blk_m, pid_n * group_size])

    return test_kernel_

test_kernel(128, out_dtype=BF16)  # compiles
test_kernel(128, out_dtype=FP8)   # raises an error

The FP8 variant logs repeated dtype mismatch fallbacks and then aborts with tvm.error.InternalError: Check failed: (!type_name.empty()) is false while invoking target.build.tilelang_cuda_without_compile.

I’m not sure why this happens.

L1aoXingyu, Oct 16 '25 02:10

The compilation error is caused by the absence of a direct bf16-to-fp8 cast function in CUTLASS's float8.h. To work around it, convert through an intermediate type: first bf16 to fp32, then fp32 to fp8. Here is my code, which compiles successfully:

# Reuses the imports and the FP8/BF16 constants from the reproduction above.
@tilelang.jit
def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
    M = T.symbolic("M")
    blk_m = 128
    group_size = 128

    @T.prim_func
    def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (pid_m, pid_n):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_shared_fp32 = T.alloc_shared((blk_m, group_size), "float32")
            T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
            T.copy(x_shared, x_shared_fp32)  # bf16 -> fp32
            T.copy(x_shared_fp32, Y[pid_m * blk_m, pid_n * group_size])  # fp32 -> fp8

    return test_kernel_
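Why the detour through fp32 should cost no accuracy: bf16 is exactly the top 16 bits of an fp32, so widening bf16 to fp32 is lossless, and the only rounding happens in the fp32-to-fp8 step, which is the same single rounding a correctly rounded direct bf16-to-fp8 cast would perform. Below is a minimal pure-Python model of that two-step path (not TileLang or CUTLASS code; all function names are hypothetical). It assumes the OCP E4M3 format (exponent bias 7, no infinities, max finite value 448), round-to-nearest-even, and saturation of out-of-range values to ±448; the saturation behavior is an assumption and may differ from CUTLASS's own converter.

```python
import bisect
import math
import struct

E4M3_MAX = 448.0  # largest finite E4M3 value: 1.75 * 2^8


def bf16_bits_to_fp32(bits: int) -> float:
    """Widen a bf16 bit pattern to fp32. Exact: bf16 is the upper half of fp32."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def e4m3_to_float(code: int) -> float:
    """Decode an 8-bit E4M3 code (sign | 4-bit exponent | 3-bit mantissa)."""
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> 3) & 0xF
    mant = code & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # S.1111.111 is NaN; E4M3 has no infinities
    if exp == 0:
        return sign * mant * 2.0 ** -9  # subnormal: (mant / 8) * 2^(1 - 7)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


# Values of the non-negative finite codes 0x00..0x7E, strictly increasing.
_POS_VALUES = [e4m3_to_float(c) for c in range(0x7F)]


def fp32_to_e4m3(x: float) -> int:
    """Round x to E4M3 with ties-to-even, saturating to +/-448 (assumed)."""
    if math.isnan(x):
        return 0x7F
    sign = 0x80 if math.copysign(1.0, x) < 0 else 0x00
    a = abs(x)
    if a >= E4M3_MAX:
        return sign | 0x7E  # saturate to the largest finite magnitude
    hi = bisect.bisect_left(_POS_VALUES, a)  # first code whose value >= a
    if _POS_VALUES[hi] == a:
        return sign | hi  # exactly representable
    lo = hi - 1
    d_lo = a - _POS_VALUES[lo]
    d_hi = _POS_VALUES[hi] - a
    if d_lo < d_hi:
        return sign | lo
    if d_hi < d_lo:
        return sign | hi
    # Tie: pick the code with an even mantissa LSB (the code's LSB).
    return sign | (lo if lo % 2 == 0 else hi)


def bf16_to_e4m3(bits: int) -> int:
    """bf16 -> fp32 (exact) -> E4M3 (single rounding)."""
    return fp32_to_e4m3(bf16_bits_to_fp32(bits))
```

For example, `bf16_to_e4m3(0x3F80)` (bf16 1.0) gives `0x38`, and the midpoint 1.0625 between 1.0 and 1.125 rounds to the even code `0x38`.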

Cunxiao2002, Oct 16 '25 02:10

@Cunxiao2002 would you mind helping override the CUTLASS implementation to address this issue?

LeiWang1999, Oct 19 '25 04:10

ok, I will take a look.

Cunxiao2002, Oct 19 '25 06:10