
[TIR] Enhance and fix tensorize schedule for some case

Open LeiWang1999 opened this issue 1 year ago • 2 comments

To optimize i4_to_f16 decoding, we can use advanced hardware instructions that perform fast type conversion, which reduces the cost of decoding. In TVM, this can be done with tensorize.
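For context, the scalar decode that such instructions would replace can be sketched in plain Python. This is only an illustration of the shift-and-mask pattern used by the test case in this description, not TVM code: the helper names `to_int32`, `decode_i4`, `pack_i4`, and `decode_word` are hypothetical, and a Python `float` stands in for the float16 cast.

```python
def to_int32(x: int) -> int:
    """Wrap a Python int to a signed 32-bit value (two's complement)."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x & 0x80000000 else x


def decode_i4(packed: int, idx: int) -> int:
    """Extract the idx-th signed 4-bit value (idx in 0..7) from an int32 word."""
    nibble = (packed >> (idx * 4)) & 15   # isolate the 4-bit field
    widened = to_int32(nibble << 28)      # move the field's sign bit to bit 31
    return widened >> 28                  # arithmetic shift sign-extends it


def pack_i4(values) -> int:
    """Pack eight signed 4-bit values into one int32 word (helper for the demo)."""
    word = 0
    for i, v in enumerate(values):
        word |= (v & 0xF) << (i * 4)
    return to_int32(word)


def decode_word(packed: int):
    """Decode all eight values of a word; float() stands in for the float16 cast."""
    return [float(decode_i4(packed, i)) for i in range(8)]


values = [-8, -1, 0, 1, 7, -3, 2, -5]
word = pack_i4(values)
decoded = decode_word(word)
```

Each output element costs two shifts, a mask, another shift, and a cast, which is the per-element work that a tensorized intrinsic is meant to avoid.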

To tensorize decoding, this PR extends the call-expression handling in ir_comparator, which is necessary because the decode block consists of call expressions.

Moreover, the comparator currently simplifies the lhs expr, but the tensor intrin descs are not simplified. The two sides are therefore inconsistent and the comparison fails; see this PR: https://github.com/apache/tvm/pull/14108.

For example, we provide a test case for this situation:

def test_tensorize_arith_simplification():
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))

In the desc, the indices [v1 // 8] and [v1 % 8] should be simplified to [0] and [v1], respectively, so that the desc matches the simplified lhs expr.
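The mismatch can be illustrated with a toy structural comparison in plain Python. This is a minimal sketch, not TVM's arithmetic analyzer or ir_comparator: expressions are nested tuples, the `simplify` helper is hypothetical, and it assumes that `v1` ranges over [0, 8) inside the desc (the desc itself is not shown in this description).

```python
def simplify(expr, extents):
    """Toy simplifier for floordiv/floormod of a variable by a constant.

    `expr` is an int, a variable name, or a tuple (op, lhs, rhs).
    `extents` maps a variable name to its extent: the variable is in [0, extent).
    """
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = simplify(lhs, extents), simplify(rhs, extents)
    if isinstance(lhs, str) and isinstance(rhs, int) and rhs > 0:
        extent = extents.get(lhs)
        if extent is not None and extent <= rhs:
            if op == "floordiv":
                return 0      # v // c == 0 when 0 <= v < c
            if op == "floormod":
                return lhs    # v % c == v when 0 <= v < c
    return (op, lhs, rhs)


extents = {"v1": 8}  # assumption: v1 covers [0, 8) in the desc

# Index and shift amount as written in an unsimplified desc.
desc_index = ("floordiv", "v1", 8)
desc_shift = ("mul", ("floormod", "v1", 8), 4)

# What the already-simplified lhs expr looks like.
lhs_index = 0
lhs_shift = ("mul", "v1", 4)

matches_before = (desc_index, desc_shift) == (lhs_index, lhs_shift)
matches_after = (
    simplify(desc_index, extents),
    simplify(desc_shift, extents),
) == (lhs_index, lhs_shift)
```

A purely structural comparison rejects the unsimplified desc even though both sides compute the same values; once both sides go through the same simplification, they compare equal.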

To simplify the tensor intrin's desc, we wrap and reuse tir::transform::simplify so that it supports simplifying a single stmt.

LeiWang1999 avatar Feb 13 '24 13:02 LeiWang1999

cc @Hzfengsy and @Lunderberg, it looks like PR #13299 provides a stmt_simplify declaration but does not provide an implementation.

LeiWang1999 avatar Feb 13 '24 14:02 LeiWang1999

cc @vinx13

tqchen avatar Feb 13 '24 14:02 tqchen

@vinx13 @spectrometerHBH would you mind taking a look at this PR?

tqchen avatar Mar 04 '24 14:03 tqchen