[TIR] Enhance and fix tensorize schedule for some cases
To optimize i4_to_f16 decoding, we can use advanced hardware instructions for fast type conversion, which reduces the cost of decoding. In TVM, this can be done through tensorization.
To tensorize the decoding, this PR extends how ir_comparator handles Call nodes. This is necessary because the decode block is composed of call expressions.
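To illustrate what comparing call expressions involves, here is a minimal sketch on a hypothetical mini-IR (`Var`, `IntImm`, `Call` and `match` are illustrative names, not TVM's classes or the actual ir_comparator code). It assumes two calls match when their operators agree and every argument matches pairwise, with each description variable bound consistently to one workload variable.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical mini-IR for illustration only; these are not TVM's node classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class IntImm:
    value: int


@dataclass(frozen=True)
class Call:
    op: str                    # operator name, e.g. "tir.shift_right"
    args: Tuple["Expr", ...]   # call arguments


Expr = Union[Var, IntImm, Call]


def match(lhs: Expr, rhs: Expr, var_map: Dict[Var, Var]) -> bool:
    """Structurally compare workload expr `lhs` with desc expr `rhs`.

    Each desc variable must map consistently to a single workload variable.
    """
    if isinstance(lhs, IntImm) and isinstance(rhs, IntImm):
        return lhs.value == rhs.value
    if isinstance(lhs, Var) and isinstance(rhs, Var):
        # Bind rhs -> lhs on first sight; afterwards the binding must agree.
        return var_map.setdefault(rhs, lhs) == lhs
    if isinstance(lhs, Call) and isinstance(rhs, Call):
        # Calls match when the operators agree and all arguments match pairwise.
        return (
            lhs.op == rhs.op
            and len(lhs.args) == len(rhs.args)
            and all(match(a, b, var_map) for a, b in zip(lhs.args, rhs.args))
        )
    # Different node kinds never match.
    return False


# Usage: the workload and desc differ only in variable names, so they match.
x, k = Var("x"), Var("k")
workload = Call("tir.bitwise_and", (Call("tir.shift_right", (x, IntImm(4))), IntImm(15)))
desc = Call("tir.bitwise_and", (Call("tir.shift_right", (k, IntImm(4))), IntImm(15)))
print(match(workload, desc, {}))  # True
```

Without a `Call` branch, such a comparator would fall through to `return False` for every call expression, which is why a decode block built from `T.shift_right`/`T.bitwise_and` calls could not be matched.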
Moreover, the comparator currently simplifies the LHS expression but not the tensor intrinsic descriptions. The two sides can therefore be inconsistent, and the comparison fails; see https://github.com/apache/tvm/pull/14108.
The following test case covers this situation:
```python
from tvm.script import tir as T


def test_tensorize_arith_simplification():
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                # Extract the (v1 % 8)-th 4-bit field, sign-extend it, then cast to float16.
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))
    # fmt: on
```
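The decode expression above can be emulated in plain Python to check what it computes. This is a minimal sketch, not TVM code: `to_int32`, `pack_i4` and `decode_i4` are hypothetical helpers, it assumes eight int4 values are packed into each int32 word lowest nibble first (matching the `v1 // 8` word index and `v1 % 8 * 4` shift), and it approximates the float16 cast with a Python float, which is exact for values in [-8, 7].

```python
def to_int32(x: int) -> int:
    """Wrap a Python int to a signed 32-bit two's-complement value."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x & 0x80000000 else x


def pack_i4(values) -> int:
    """Pack eight signed int4 values (each in [-8, 7]) into one int32, lowest nibble first."""
    assert len(values) == 8
    word = 0
    for i, v in enumerate(values):
        word |= (v & 0xF) << (i * 4)
    return to_int32(word)


def decode_i4(word: int, idx: int) -> float:
    """Mirror the decode block for element idx in [0, 8) of one int32 word."""
    nibble = (word >> (idx * 4)) & 15   # T.shift_right, then T.bitwise_and(..., 15)
    shifted = to_int32(nibble << 28)    # T.shift_left(..., 28) in int32: the nibble's sign bit lands in bit 31
    return float(shifted >> 28)         # arithmetic T.shift_right(..., 28) sign-extends; then the cast


values = [-8, -1, 0, 1, 7, 3, -4, 5]
word = pack_i4(values)
print([decode_i4(word, i) for i in range(8)])  # [-8.0, -1.0, 0.0, 1.0, 7.0, 3.0, -4.0, 5.0]
```

The shift-left-by-28 / shift-right-by-28 pair is what turns an unsigned 4-bit field into a signed value; a fast hardware conversion path would have to reproduce exactly this result.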
In the tensor intrinsic's description, `v1 // 8` and `v1 % 8` should be simplified to `0` and `v1` respectively, so that they match the already-simplified LHS expression.
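The rewrite above is only valid because `v1` is bounded inside the intrinsic's description. A minimal sketch on a hypothetical mini-IR (`Var`, `Const`, `BinOp` and `simplify` are illustrative, not TVM's arithmetic simplifier), assuming the description iterates `v1` over [0, 8), shows why an unsimplified description fails to match a simplified LHS and matches once both sides are simplified:

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical mini-IR; "floordiv"/"floormod" mirror the `//` and `%` used in TVMScript.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class BinOp:
    op: str      # "floordiv" or "floormod"
    a: "Expr"
    b: "Expr"


Expr = Union[Var, Const, BinOp]


def simplify(e: Expr, ranges: Dict[Var, Tuple[int, int]]) -> Expr:
    """Rewrite v // c -> 0 and v % c -> v when v is known to lie in [0, c)."""
    if not isinstance(e, BinOp):
        return e
    a, b = simplify(e.a, ranges), simplify(e.b, ranges)
    if isinstance(a, Var) and isinstance(b, Const) and a in ranges:
        lo, hi = ranges[a]  # half-open interval [lo, hi)
        if 0 <= lo and hi <= b.value:
            if e.op == "floordiv":
                return Const(0)
            if e.op == "floormod":
                return a
    return BinOp(e.op, a, b)


v1 = Var("v1")
ranges = {v1: (0, 8)}  # assumed extent of v1 inside the intrinsic's description
desc_index = BinOp("floordiv", v1, Const(8))
desc_shift = BinOp("floormod", v1, Const(8))

lhs_index = Const(0)   # the workload side, already simplified by the comparator
print(desc_index == lhs_index)                    # False: unsimplified desc does not match
print(simplify(desc_index, ranges) == lhs_index)  # True once the desc is simplified too
print(simplify(desc_shift, ranges))               # Var(name='v1')
```

The range check matters: if `v1` could reach 8 or more, neither rewrite would preserve the meaning of the expression, so the simplifier leaves it unchanged.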
To simplify the tensor intrinsic's description, we wrap and reuse tir::transform::simplify so that it can be applied to a single stmt.
cc @Hzfengsy and @Lunderberg: it looks like PR #13299 provides a stmt_simplify declaration but does not provide an implementation.
cc @vinx13
@vinx13 @spectrometerHBH, would you mind taking a look at this PR?