tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Codegen, CUDA] Add FP8 Tensor Core Codegen

Open LeiWang1999 opened this issue 1 year ago • 2 comments

Major changes of this pull request:

  • Lower the compute-version requirement of the fp8-related tests (`requires_cuda_compute_version`) from 9.0 to 8.9, since the sm_89 (Ada) architecture also supports fp8 tensor cores; this is the platform I tested on.
  • Improve fp8 vector load/store support: previously, TVM only supported float8x1/x2/x4 loads; this PR adds support for float8x8 and float8x16 loads.
  • Refactor the interfaces of the get_mma_intrin_group and get_mma_intrin functions. The prior implementation assumed that inputs A and B had the same data type, but fp8 tensor cores can process mixed combinations such as e5m2×e5m2, e5m2×e4m3, e4m3×e4m3, and e4m3×e5m2. Note: this change may affect code in MLC that uses get_mma_intrin_group.
  • Implement support for fp8 mma code generation and associated tests.

Correctness check:

import tvm
from tvm import te
import numpy as np
import tvm.testing
from tvm.script import tir as T
import os
from tvm.tir.tensor_intrin.cuda import (
    get_mma_intrin_group,
    shared_16x16_to_ldmatrix_32x8_layout,
    shared_32x16_to_ldmatrix_32x16_layout,
    shared_16x32_to_ldmatrix_32x16_layout,
)

# Problem size: C[M, N] = A[M, K] @ B[N, K]^T (see MyModule below).
M = 1024
N = 1024
K = 1024

# Thread-block tile sizes along M, N and K.
BM = 64
BN = 64
BK = 64
warp_size = 32
# Warps per thread block: block_row_warps along M (bound to threadIdx.z in the
# schedule) and block_col_warps along N (bound to threadIdx.y in the schedule).
block_row_warps = 2
block_col_warps = 4

# Input (A, B) data type and output/accumulation (C) data type.
indtype = "e4m3_float8"
out_dtype = "float32"
# Alternative int8 configuration:
# indtype = "int8"
# out_dtype = "int32"
# MMA tensor intrinsics ("load_a", "load_b", "compute", "init", "store") used
# by sch.tensorize further down. The two positional scopes are presumably the
# operand load scope ("shared") and the result store scope ("global") --
# confirm against get_mma_intrin_group. trans_b=True matches B being laid out
# as [N, K].
intrin_group = get_mma_intrin_group(
    "shared",
    "global",
    in_dtype=indtype,
    out_dtype=out_dtype,
    trans_a=False,
    trans_b=True,
    not_use_mma_store_intrinic=False,
)

# Naive TensorIR GEMM that the schedule below transforms:
#   C[i, j] = sum_k A[i, k] * B[j, k]   (B is stored as [N, K], i.e. transposed)
# Both inputs are cast to out_dtype before they are multiplied and accumulated.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=indtype)
        B = T.match_buffer(b, [N, K], dtype=indtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)

        # i, j are spatial axes and k is the reduction axis ("SSR"). The block
        # name "B" is the one sch.get_block("B") looks up in the schedule.
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # NOTE(review): the initializer is an int32 literal even
                    # when out_dtype is float32; a literal of out_dtype would
                    # be clearer -- confirm how TVMScript casts this store.
                    C[vi, vj] = T.int32(0)
                C[vi, vj] = C[vi, vj] + \
                    A[vi, vk].astype(out_dtype) * B[vj, vk].astype(out_dtype)


ir_module = MyModule
print(ir_module)
sch = tvm.tir.Schedule(ir_module, debug_mask="all")

block_b = sch.get_block("B")

# Tile i / j / k into BM x BN x BK thread-block tiles.
(i, j, k) = sch.get_loops(block_b)
by, i = sch.split(i, factors=[None, BM])
bx, j = sch.split(j, factors=[None, BN])
bk, k = sch.split(k, factors=[None, BK])

sch.reorder(by, bx, bk, i, j, k)

# The outer tile loops form the CUDA grid.
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")


# Split each block tile across warps: block_row_warps along i and
# block_col_warps along j (threadIdx.x is used for the lanes of a warp).
block_b_tz, block_b_inner_i = sch.split(
    i, factors=[block_row_warps, None])

block_b_ty, block_b_inner_j = sch.split(
    j, factors=[block_col_warps, None])

# Move the K-tile loop bk inside the warp loops so each warp iterates over K.
sch.reorder(block_b_tz, block_b_ty, bk, block_b_inner_i, block_b_inner_j, k)

sch.bind(block_b_tz, "threadIdx.z")
sch.bind(block_b_ty, "threadIdx.y")

# Schedule the shared-memory copies of A and B.

def fetch_to_shared(block, idx):
    """Stage read buffer ``idx`` of ``block`` through shared memory.

    The shared copy is computed at the ``bk`` (K-tile) loop. Its two innermost
    loops are fused and split as [rest, block_col_warps, warp_size, 16]; the
    middle two pieces are bound to threadIdx.y / threadIdx.x, and the last one
    is vectorized (16 elements per thread). The second-to-last axis of the
    shared buffer is storage-aligned with factor 32 and offset 0.

    Relies on the module-level ``sch``, ``bk``, ``block_col_warps`` and
    ``warp_size``.
    """
    shared_block = sch.cache_read(block, idx, "shared")
    sch.compute_at(shared_block, bk)
    copy_loops = sch.get_loops(shared_block)
    fused_loop = sch.fuse(*copy_loops[-2:])
    lanes_per_thread = 16
    _, ty_loop, tx_loop, vec_loop = sch.split(
        fused_loop,
        factors=[None, block_col_warps, warp_size, lanes_per_thread],
    )
    sch.bind(tx_loop, "threadIdx.x")
    sch.bind(ty_loop, "threadIdx.y")
    sch.vectorize(vec_loop)
    sch.storage_align(shared_block, 0, axis=-2, factor=32, offset=0)

# Stage A (read buffer 0 of block "B") through shared memory.
fetch_to_shared(block_b, 0)
# Stage B (read buffer 1 of block "B") through shared memory.
fetch_to_shared(block_b, 1)


# Split the per-warp loops into mma_m x mma_n x mma_k (16 x 16 x 32) tiles so
# the three innermost loops can be tensorized with the MMA intrinsics.

mma_m = 16
mma_n = 16
mma_k = 32

block_b_inner_i, block_b_inner_i_tc = sch.split(
    block_b_inner_i, factors=[None, mma_m])
block_b_inner_j, block_b_inner_j_tc = sch.split(
    block_b_inner_j, factors=[None, mma_n])
k, k_tc = sch.split(k, factors=[None, mma_k])

# Put the 16 x 16 x 32 tile loops innermost.
sch.reorder(block_b_inner_i, block_b_inner_j,
            k, block_b_inner_i_tc, block_b_inner_j_tc, k_tc)

# Warp-scope copies of the A / B fragments, refreshed at every mma_k step.
A_warp = sch.cache_read(block_b, 0, "warp")
B_warp = sch.cache_read(block_b, 1, "warp")
sch.compute_at(A_warp, k)
sch.compute_at(B_warp, k)
# Warp-scope accumulator for C, written back once per warp tile.
C_warp = sch.cache_write(block_b, 0, "warp")
sch.reverse_compute_at(C_warp, block_b_ty)

# Tile the C write-back loops into mma_m x mma_n pieces for the store intrinsic.
ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, mma_m])
jo, ji = sch.split(jj, factors=[None, mma_n])
sch.reorder(io, jo, ii, ji)


def tile_wmma_fragment(block_read, height, width):
    """Return the outer of the two innermost loops of ``block_read``.

    ``height`` and ``width`` are accepted only for interface compatibility and
    are not used. Relies on the module-level schedule ``sch``.
    """
    outer_loop, _inner_loop = sch.get_loops(block_read)[-2:]
    return outer_loop

# Loops covering the warp-scope copies of A and B; these are the loops later
# tensorized with the "load_a" / "load_b" intrinsics.
loop_a = tile_wmma_fragment(A_warp, mma_m, mma_k)

loop_b = tile_wmma_fragment(B_warp, mma_n, mma_k)

# Split the C initialization out of the reduction at the bk loop so that it
# can be tensorized on its own with the "init" intrinsic.
block_init_c = sch.decompose_reduction(
    block_b, bk)

def index_map_A(i, j):
    """Layout map for 16x32 warp fragments.

    Returns the fragment tile indices ``(i // 16, j // 32)`` followed by
    whatever ``shared_16x32_to_ldmatrix_32x16_layout`` yields for the
    coordinates inside the tile (presumably a thread/element pair of the
    ldmatrix layout -- confirm against that helper).
    """
    tile_coords = (i // 16, j // 32)
    lane_coords = shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32)
    return (*tile_coords, *lane_coords)

def index_map_B(i, j):
    """Layout map for 32x16 warp fragments.

    Returns the fragment tile indices ``(i // 32, j // 16)`` followed by
    whatever ``shared_32x16_to_ldmatrix_32x16_layout`` yields for the
    coordinates inside the tile (presumably a thread/element pair of the
    ldmatrix layout -- confirm against that helper).
    """
    tile_coords = (i // 32, j // 16)
    lane_coords = shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16)
    return (*tile_coords, *lane_coords)

def index_map_C(i, j):
    """Layout map for 16x16 accumulator fragments.

    Returns the fragment tile indices ``(i // 16, j // 16)`` followed by
    whatever ``shared_16x16_to_ldmatrix_32x8_layout`` yields for the
    coordinates inside the tile (presumably a thread/element pair of the
    32x8 layout -- confirm against that helper).
    """
    tile_coords = (i // 16, j // 16)
    lane_coords = shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16)
    return (*tile_coords, *lane_coords)


# Rewrite the warp-scope buffers into the fragment layouts presumably expected
# by the MMA intrinsics. B is stored as [N, K] (trans_b=True), so its warp tile
# is mma_n x mma_k = 16 x 32 like A's; that is why B_warp also uses index_map_A
# (index_map_B describes a 32 x 16 tile and is not used here).
sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_A)
sch.transform_layout(C_warp, ("read", 0), index_map_C)


# Replace the warp-scope fragment loads with the MMA load intrinsics.
sch.tensorize(loop_a, intrin_group["load_a"])
sch.tensorize(loop_b, intrin_group["load_b"])

# Tensorize the innermost 16 x 16 x 32 loop nest with the MMA compute intrinsic.
sch.tensorize(block_b_inner_i_tc, intrin_group["compute"])

# Tensorize the accumulator initialization and the write-back of C.
sch.tensorize(sch.get_loops(block_init_c)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(C_warp)[-2], intrin_group["store"])


# Compile the scheduled module for CUDA device 0.
ctx = tvm.cuda(0)
cuda_mod = tvm.build(sch.mod, target="cuda")

def map_numpy_type(intype):
    """Translate a TVM dtype name to the dtype name used with NumPy.

    The TVM fp8 names "e4m3_float8" / "e5m2_float8" become "float8_e4m3fn" /
    "float8_e5m2" (presumably provided to NumPy by ml_dtypes); every other
    name is returned unchanged.
    """
    tvm_to_numpy = {
        'e4m3_float8': 'float8_e4m3fn',
        'e5m2_float8': 'float8_e5m2',
    }
    return tvm_to_numpy.get(intype, intype)

numpytype_a = map_numpy_type(indtype)
numpytype_b = map_numpy_type(indtype)
numpytype_c = map_numpy_type(out_dtype)
# Random inputs quantized to the input dtype. B is laid out as [N, K] to match
# the kernel's B buffer and the b.T in the reference product below.
a = np.random.uniform(low=-5, high=5, size=(M * K)).reshape((M, K)).astype(numpytype_a)
b = np.random.uniform(low=-5, high=5, size=(N * K)).reshape((N, K)).astype(numpytype_b)
# Reference result: cast the (already quantized) inputs to the output dtype
# before multiplying, mirroring the kernel, which casts A and B to out_dtype
# and accumulates in out_dtype.
out = np.matmul(a.astype(numpytype_c), b.astype(numpytype_c).T)

print("numpy_simulated:", out)

cuda_a = tvm.nd.array(a, ctx)
cuda_b = tvm.nd.array(b, ctx)
cuda_c = tvm.nd.array(np.zeros((M, N), dtype=numpytype_c), ctx)
cuda_mod(cuda_a, cuda_b, cuda_c)

print("codegen:", cuda_c)
num_flops = 2 * M * K * N
num_runs = 1
timer_cuda_mod = cuda_mod.time_evaluator(
    cuda_mod.entry_name, ctx, number=num_runs)

# Mean runtime of the kernel; it is reported below as t * 1e3 milliseconds.
t = timer_cuda_mod(cuda_a, cuda_b, cuda_c).mean

# num_flops / (t * 1e3) / 1e6 == num_flops / t / 1e9, i.e. GFLOP per second.
GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." %
      (num_runs, t * 1e3, GFLOPS))

expected output:

numpy_simulated: [[-410.33817   -30.429443 -470.51312  ...   64.58632  -381.49658
    14.920105]
 [  56.357788  744.9746    -29.630783 ...  -44.779022  298.5943
   -24.109558]
 [  77.765305 -426.8894    286.35736  ...   10.655792 -129.63507
   232.30026 ]
 ...
 [  39.094635  -47.508118 -225.59912  ...  775.10614  -109.92264
   268.50952 ]
 [-813.8422    111.21069  -316.5697   ...  455.90875   -37.09839
   478.28406 ]
 [ 122.78345   148.104     340.1291   ... -304.5721   -115.578735
  -639.9563  ]]
codegen: [[-410.28125    -30.441406  -470.09375   ...   64.66406   -381.5
    14.8203125]
 [  56.367188   744.8125     -29.597656  ...  -44.695312   298.625
   -24.148438 ]
 [  77.65625   -426.71875    286.3125    ...   10.746094  -129.6875
   232.34375  ]
 ...
 [  39.191406   -47.539062  -225.57812   ...  774.9375    -109.875
   268.46875  ]
 [-813.625      111.109375  -316.46875   ...  455.96875    -37.08203
   478.0625   ]
 [ 122.75       148.10938    339.84375   ... -304.5       -115.546875
  -639.8125   ]]

Please CC @yzh119

LeiWang1999 avatar Apr 29 '24 05:04 LeiWang1999

cc @vinx13

Hzfengsy avatar Apr 29 '24 06:04 Hzfengsy

I have reverted the signature change to get_mma_intrin_group, so this PR no longer affects code that uses this function. I believe this PR is now ready for review. @Hzfengsy

LeiWang1999 avatar Jul 03 '24 06:07 LeiWang1999