[Codegen, CUDA] Add FP8 Tensor Core Codegen
Major changes of this pull request:
- Change the fp8-related test `requires_cuda_compute_version` from 9 to 8.9, since the sm_89 Ada architecture also supports fp8 tensor cores and is the platform I have tested on (a hedged sketch of this decorator change is shown below).
- Improve fp8 vector load/store capabilities; previously, TVM only supported float8x4/2/1 loads, but this PR introduces support for float8x8/16 loads (see the vectorized-copy sketch below the list).
- Refactor the interface of the `get_mma_intrin_group` and `get_mma_intrin` functions, as the prior implementation assumed that input A and input B share the same datatype. However, fp8 tensor cores can process combinations such as e5m2 × e5m2, e5m2 × e4m3, e4m3 × e4m3, or e4m3 × e5m2. Note: this change may affect code in MLC that uses `get_mma_intrin_group`.
- Implement support for fp8 mma code generation and the associated tests.
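A hedged sketch of the first change, assuming the `tvm.testing.requires_cuda_compute_version` marker that gates the existing tests; the test name and body are illustrative, not taken from this PR:

```python
import tvm.testing

# Before: fp8 tensor-core tests were gated on Hopper (sm_90).
# @tvm.testing.requires_cuda_compute_version(9)
# After: Ada (sm_89) also has fp8 tensor cores, so gate on compute capability 8.9.
@tvm.testing.requires_cuda_compute_version(8, 9)
def test_fp8_mma_gemm():
    ...
```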
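And a minimal TIR sketch of the wider fp8 vector accesses from the second change; the kernel name, buffer size, and thread count are illustrative assumptions, not code from this PR:

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def fp8_vector_copy(a: T.handle, b: T.handle):
    T.func_attr({"global_symbol": "fp8_vector_copy", "tir.noalias": True})
    A = T.match_buffer(a, [128], dtype="e4m3_float8")
    B = T.match_buffer(b, [128], dtype="e4m3_float8")
    for i in T.thread_binding(8, thread="threadIdx.x"):
        # The vectorized inner loop can now lower to a single float8x16
        # (128-bit) access instead of being split into x4/x2/x1 pieces.
        for v in T.vectorized(16):
            B[i * 16 + v] = A[i * 16 + v]

# cuda_mod = tvm.build(fp8_vector_copy, target="cuda")  # needs an fp8-capable GPU (sm_89+)
```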
Check out the correctness:

```python
import tvm
from tvm import te
import numpy as np
import tvm.testing
from tvm.script import tir as T
import os
from tvm.tir.tensor_intrin.cuda import (
get_mma_intrin_group,
shared_16x16_to_ldmatrix_32x8_layout,
shared_32x16_to_ldmatrix_32x16_layout,
shared_16x32_to_ldmatrix_32x16_layout,
)
M = 1024
N = 1024
K = 1024
BM = 64
BN = 64
BK = 64
warp_size = 32
block_row_warps = 2
block_col_warps = 4
indtype = "e4m3_float8"
out_dtype = "float32"
# indtype = "int8"
# out_dtype = "int32"
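# Fetch the mma intrinsic group (init / load_a / load_b / compute / store)
# for fp8 inputs accumulating into fp32; B is consumed transposed (trans_b=True).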
intrin_group = get_mma_intrin_group(
"shared",
"global",
in_dtype=indtype,
out_dtype=out_dtype,
trans_a=False,
trans_b=True,
not_use_mma_store_intrinic=False,
)
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=indtype)
        B = T.match_buffer(b, [N, K], dtype=indtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.cast(0, out_dtype)
                C[vi, vj] = C[vi, vj] + \
                    A[vi, vk].astype(out_dtype) * B[vj, vk].astype(out_dtype)
ir_module = MyModule
print(ir_module)
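# Block-level tiling: each thread block computes a BM x BN (64x64) output tile,
# iterating over the reduction dimension in BK (64) chunks.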
sch = tvm.tir.Schedule(ir_module, debug_mask="all")
block_b = sch.get_block("B")
(i, j, k) = sch.get_loops(block_b)
by, i = sch.split(i, factors=[None, BM])
bx, j = sch.split(j, factors=[None, BN])
bk, k = sch.split(k, factors=[None, BK])
sch.reorder(by, bx, bk, i, j, k)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
block_b_tz, block_b_inner_i = sch.split(
i, factors=[block_row_warps, None])
block_b_ty, block_b_inner_j = sch.split(
j, factors=[block_col_warps, None])
sch.reorder(block_b_tz, block_b_ty, bk, block_b_inner_i, block_b_inner_j, k)
sch.bind(block_b_tz, "threadIdx.z")
sch.bind(block_b_ty, "threadIdx.y")
# schedule the shared memory
def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    sch.compute_at(block_read, bk)
    vector_size = 16
    fused = sch.fuse(*sch.get_loops(block_read)[-2:])
    _, f_1, f_2, f_3 = sch.split(
        fused, factors=[None, block_col_warps, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")
    sch.vectorize(f_3)
    offset = 0
    sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
# schedule A
fetch_to_shared(block_b, 0)
# schedule B
fetch_to_shared(block_b, 1)
# blockize for mma tensorize
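# Each mma tensor-core intrinsic covers a 16x16x32 (m, n, k) tile for 8-bit inputs.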
mma_m = 16
mma_n = 16
mma_k = 32
block_b_inner_i, block_b_inner_i_tc = sch.split(
block_b_inner_i, factors=[None, mma_m])
block_b_inner_j, block_b_inner_j_tc = sch.split(
block_b_inner_j, factors=[None, mma_n])
k, k_tc = sch.split(k, factors=[None, mma_k])
sch.reorder(block_b_inner_i, block_b_inner_j,
k, block_b_inner_i_tc, block_b_inner_j_tc, k_tc)
A_warp = sch.cache_read(block_b, 0, "warp")
B_warp = sch.cache_read(block_b, 1, "warp")
sch.compute_at(A_warp, k)
sch.compute_at(B_warp, k)
C_warp = sch.cache_write(block_b, 0, "warp")
sch.reverse_compute_at(C_warp, block_b_ty)
ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, mma_m])
jo, ji = sch.split(jj, factors=[None, mma_n])
sch.reorder(io, jo, ii, ji)
def tile_wmma_fragment(block_read, height, width):
    i, j = sch.get_loops(block_read)[-2:]
    return i
loop_a = tile_wmma_fragment(A_warp, mma_m, mma_k)
loop_b = tile_wmma_fragment(B_warp, mma_n, mma_k)
block_init_c = sch.decompose_reduction(
block_b, bk)
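# Index maps that rearrange each warp tile into the 32-thread ldmatrix
# register layout expected by the mma intrinsics.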
def index_map_A(i, j):
    return (
        i // 16,
        j // 32,
        *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
    )
def index_map_B(i, j):
    return (
        i // 32,
        j // 16,
        *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
    )
def index_map_C(i, j):
    return (
        i // 16,
        j // 16,
        *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
    )
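# With trans_b=True the B warp tile is also 16x32, so it reuses index_map_A
# (index_map_B would correspond to a non-transposed B and is unused here).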
sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_A)
sch.transform_layout(C_warp, ("read", 0), index_map_C)
sch.tensorize(loop_a, intrin_group["load_a"])
sch.tensorize(loop_b, intrin_group["load_b"])
sch.tensorize(block_b_inner_i_tc, intrin_group["compute"])
sch.tensorize(sch.get_loops(block_init_c)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(C_warp)[-2], intrin_group["store"])
ctx = tvm.cuda(0)
cuda_mod = tvm.build(sch.mod, target="cuda")
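# Map TVM's fp8 dtype names to the matching ml_dtypes/numpy dtype names
# used to build the host-side test arrays.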
def map_numpy_type(intype):
    typemap = {
        'e4m3_float8': 'float8_e4m3fn',
        'e5m2_float8': 'float8_e5m2',
    }
    if intype in typemap:
        return typemap[intype]
    else:
        return intype
numpytype_a = map_numpy_type(indtype)
numpytype_b = map_numpy_type(indtype)
numpytype_c = map_numpy_type(out_dtype)
a = np.random.uniform(low=-5, high=5, size=(M*K)).reshape((M, K)).astype(numpytype_a)
b = np.random.uniform(low=-5, high=5, size=(N*K)).reshape((N, K)).astype(numpytype_b)
out = np.matmul(a, b.T)
print("numpy_simulated:", out)
cuda_a = tvm.nd.array(a, ctx)
cuda_b = tvm.nd.array(b, ctx)
cuda_c = tvm.nd.array(np.zeros((M, N)).astype(numpytype_c), ctx)
cuda_mod(cuda_a, cuda_b, cuda_c)
print("codegen:", cuda_c)
num_flops = 2 * M * K * N
num_runs = 1
timer_cuda_mod = cuda_mod.time_evaluator(
cuda_mod.entry_name, ctx, number=num_runs)
t = timer_cuda_mod(cuda_a, cuda_b, cuda_c).mean
GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." %
      (num_runs, t * 1e3, GFLOPS))
```
Expected output:

```
numpy_simulated: [[-410.33817 -30.429443 -470.51312 ... 64.58632 -381.49658
14.920105]
[ 56.357788 744.9746 -29.630783 ... -44.779022 298.5943
-24.109558]
[ 77.765305 -426.8894 286.35736 ... 10.655792 -129.63507
232.30026 ]
...
[ 39.094635 -47.508118 -225.59912 ... 775.10614 -109.92264
268.50952 ]
[-813.8422 111.21069 -316.5697 ... 455.90875 -37.09839
478.28406 ]
[ 122.78345 148.104 340.1291 ... -304.5721 -115.578735
-639.9563 ]]
codegen: [[-410.28125 -30.441406 -470.09375 ... 64.66406 -381.5
14.8203125]
[ 56.367188 744.8125 -29.597656 ... -44.695312 298.625
-24.148438 ]
[ 77.65625 -426.71875 286.3125 ... 10.746094 -129.6875
232.34375 ]
...
[ 39.191406 -47.539062 -225.57812 ... 774.9375 -109.875
268.46875 ]
[-813.625 111.109375 -316.46875 ... 455.96875 -37.08203
478.0625 ]
[ 122.75 148.10938 339.84375 ... -304.5 -115.546875
-639.8125 ]]
```
Please CC @yzh119
cc @vinx13
I just removed the modification to the signature of `get_mma_intrin_group`, so this PR no longer affects code that uses this function. I believe this PR is now ready for review. @Hzfengsy