FP8 JIT compilation fails with TVM InternalError
Minimal reproduction code:
import tilelang
import tilelang.language as T
FP8 = "float8_e4m3"
BF16 = "bfloat16"
@tilelang.jit
def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
    """Build a tile kernel that stages X through shared memory into Y.

    The output tensor dtype is `out_dtype`; with out_dtype=FP8 the direct
    in_dtype -> out_dtype copy is what triggers the reported compile failure.
    """
    M = T.symbolic("M")  # dynamic row count
    blk_m = 128          # rows handled per program
    group_size = 128     # columns handled per program

    @T.prim_func
    def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]):
        grid_m = T.ceildiv(M, blk_m)
        grid_n = T.ceildiv(N, group_size)
        with T.Kernel(grid_m, grid_n, threads=128) as (pid_m, pid_n):
            # Stage one (blk_m x group_size) tile in shared memory.
            smem_tile = T.alloc_shared((blk_m, group_size), in_dtype)
            T.copy(X[pid_m * blk_m, pid_n * group_size], smem_tile)
            # Copy back out; dtype conversion happens here when dtypes differ.
            T.copy(smem_tile, Y[pid_m * blk_m, pid_n * group_size])

    return test_kernel_


test_kernel(128, out_dtype=BF16)  # compiles fine
test_kernel(128, out_dtype=FP8)   # raises tvm.error.InternalError
The FP8 variant logs repeated dtype-mismatch fallback warnings and then aborts with `tvm.error.InternalError: Check failed: (!type_name.empty()) is false` while invoking `target.build.tilelang_cuda_without_compile`.
I'm not sure why this happens.
The compilation error is caused by the absence of a direct bf16-to-fp8 cast function in CUTLASS's `float8.h`. To work around it, an intermediate conversion is required: bf16 must first be converted to fp32, and then fp32 to fp8. Here is my code, which compiles successfully:
@tilelang.jit
def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
    """Workaround variant: route the bf16 -> fp8 cast through an fp32 buffer.

    Identical to the failing kernel except that the shared tile is first
    copied into an fp32 staging buffer, so the final copy into Y only has
    to cast fp32 -> out_dtype.
    """
    M = T.symbolic("M")  # dynamic row count
    blk_m = 128          # rows handled per program
    group_size = 128     # columns handled per program

    @T.prim_func
    def test_kernel_(X: T.Tensor[(M, N), in_dtype], Y: T.Tensor[(M, N), out_dtype]):
        grid_m = T.ceildiv(M, blk_m)
        grid_n = T.ceildiv(N, group_size)
        with T.Kernel(grid_m, grid_n, threads=128) as (pid_m, pid_n):
            # Input tile in its native dtype, plus an fp32 staging tile.
            smem_tile = T.alloc_shared((blk_m, group_size), in_dtype)
            smem_stage = T.alloc_shared((blk_m, group_size), "float32")
            T.copy(X[pid_m * blk_m, pid_n * group_size], smem_tile)
            # Two-step cast: in_dtype -> fp32, then fp32 -> out_dtype.
            T.copy(smem_tile, smem_stage)
            T.copy(smem_stage, Y[pid_m * blk_m, pid_n * group_size])

    return test_kernel_
@Cunxiao2002, would you mind helping override the CUTLASS implementation to address this issue?
OK, I will take a look.