[MetaSchedule][Hexagon] conv2d produces different results after tuning

Open · psrivas2 opened this issue · 4 comments

The following PrimFunc produces different results after tuning on Hexagon.

@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
            T.writes(conv2d_nhwc[nn, yy, xx, ff])
            with T.init():
                conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
            conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
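
For context, a MetaSchedule tuning flow along the following lines produces such a transformation. This is a minimal sketch, not the exact command from the report: the llvm target, work_dir, and trial count are illustrative placeholders, the actual Hexagon run additionally needs the Hexagon RPC launcher, and the entry-point names follow tvm.meta_schedule around this TVM version and may differ in yours.

import tvm
from tvm import meta_schedule as ms

# Wrap the PrimFunc above in an IRModule and tune it; tune_tir returns a database
# of measured schedules, from which the best one is compiled back to a Schedule.
mod = tvm.IRModule({"conv2d": conv2d})
target = tvm.target.Target("llvm -num-cores 4")  # placeholder; real runs target Hexagon
database = ms.tune_tir(mod=mod, target=target, work_dir="./tune_logs", max_trials_global=64)
sch = ms.tir_integration.compile_tir(database, mod, target)
print(sch.mod.script())  # prints the tuned PrimFunc, analogous to the one below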

After tuning, the PrimFunc is transformed to the following. The applied schedule uses an SRSRS tiling structure, caches the output in a conv2d_nhwc_global buffer, parallelizes the fused outer spatial loops, and vectorizes the innermost 64-wide channel loop:

@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
    # body
    # with T.block("root")
    conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float16")
    for i0_0_i1_0_i2_0_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":T.int64(512), "pragma_unroll_explicit":T.int64(1)}):
        for i3_0 in T.serial(1):
            for i0_1_init, i1_1_init, i2_1_init, i3_1_init, i0_2_init, i1_2_init, i2_2_init in T.grid(1, 2, 16, 1, 1, 2, 1):
                for i3_2_fused_init in T.vectorized(64):
                    with T.block("conv2d_nhwc_init"):
                        nn = T.axis.spatial(1, i0_1_init + i0_2_init)
                        yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1_init * 2 + i1_2_init)
                        xx = T.axis.spatial(112, i2_2_init + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1_init)
                        ff = T.axis.spatial(64, i3_0 * 64 + i3_1_init * 64 + i3_2_fused_init)
                        T.reads()
                        T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                        T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                        conv2d_nhwc_global[nn, yy, xx, ff] = T.float16(0)
            for i4_0, i5_0, i6_0 in T.grid(1, 7, 1):
                for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2 in T.grid(1, 2, 16, 1, 7, 1, 3, 1, 2, 1):
                    for i3_2_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_update"):
                            nn = T.axis.spatial(1, i0_1 + i0_2)
                            yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1 * 2 + i1_2)
                            xx = T.axis.spatial(112, i2_2 + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1)
                            ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 64 + i3_2_fused)
                            ry = T.axis.reduce(7, i4_0 * 7 + i4_1)
                            rx = T.axis.reduce(7, i5_0 + i5_1)
                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            T.reads(conv2d_nhwc_global[nn, yy, xx, ff], lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                            T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                            conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
                for ax0, ax1, ax2 in T.grid(1, 4, 16):
                    for ax3_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_global"):
                            v0 = T.axis.spatial(1, ax0)
                            v1 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + ax1)
                            v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused % 7 * 16 + ax2)
                            v3 = T.axis.spatial(64, ax3_fused)
                            T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
                            T.writes(conv2d_nhwc[v0, v1, v2, v3])
                            conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]

The two PrimFuncs produce different results on Hexagon hardware. This needs to be investigated.
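
One way to establish which version is wrong is to compare both kernels against a float32 NumPy reference. The sketch below illustrates this; the helper conv2d_nhwc_ref, the random inputs, and the tolerances are illustrative, not from the original report.

import numpy as np

def conv2d_nhwc_ref(data, weight, stride=2):
    # NHWC convolution reference, accumulated in float32.
    N, H, W, C = data.shape        # (1, 230, 230, 3)
    KH, KW, _, F = weight.shape    # (7, 7, 3, 64)
    OH = (H - KH) // stride + 1    # 112
    OW = (W - KW) // stride + 1    # 112
    out = np.zeros((N, OH, OW, F), dtype="float32")
    for ry in range(KH):
        for rx in range(KW):
            patch = data[:, ry : ry + OH * stride : stride,
                         rx : rx + OW * stride : stride, :]
            out += patch.astype("float32") @ weight[ry, rx].astype("float32")
    return out

data = np.random.uniform(-1, 1, (1, 230, 230, 3)).astype("float16")
weight = np.random.uniform(-1, 1, (7, 7, 3, 64)).astype("float16")
ref = conv2d_nhwc_ref(data, weight)
# Compare each kernel's fp16 output against ref, with tolerances loose enough
# for a 7*7*3 = 147-element fp16 reduction, e.g.:
#   np.testing.assert_allclose(out_fp16.astype("float32"), ref, rtol=1e-2, atol=1e-2)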

psrivas2 · Dec 02 '22 21:12

Thanks @psrivas2 for reporting the issue!

Two questions that could help us learn more about the context:

  • Is it Hexagon-specific? I.e., if we tune conv2d on CPU and GPU, do the incorrect results also appear?
  • Is it only conv2d? I.e., if we tune other kernels on Hexagon, do the tuned kernels also give results different from the untuned ones?

YuchenJin · Dec 02 '22 23:12

First, it is Hexagon-specific: on CPU the tuned kernel output is the same as the untuned output. Second, I have only observed this behavior for this specific kernel. For example, after fusion, resnet has around 31 PrimFuncs; of those, only the one PrimFunc that had the above block among its fused operations produced results different from its untuned counterpart.

psrivas2 · Dec 03 '22 02:12

~~In addition to that, this is definitely an incorrect transformation of the untuned PrimFunc, as the two PrimFuncs shown above give different results even on CPU.~~

psrivas2 · Dec 03 '22 03:12

I think I have narrowed it down to the reordering of loops.

On Hexagon, the following two modules, which differ only in the order of loops i3 and i4, produce different numeric results: the max absolute difference is 0.5 and the mean difference is 0.0708. This happens only for the fp16 dtype.

@tvm.script.ir_module
class TuningBug:
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv

The loops i3 (the spatial ff axis) and i4 (the reduction ry axis) are reordered with the following schedule:

import tvm

sch = tvm.tir.Schedule(TuningBug)                      # schedule over the module above
b0 = sch.get_block("conv2d_nhwc", func_name="conv2d")  # the compute block
i0, i1, i2, i3, i4, i5, i6 = sch.get_loops(b0)         # four spatial + three reduction loops
sch.reorder(i4, i3)                                    # move ry (i4) outside ff (i3)

The modified module is shown below. Note that the swap does not change the per-element accumulation order over (ry, rx, rc), so even sequential fp16 execution should give identical outputs; the observed divergence therefore points at codegen rather than floating-point reassociation.

@tvm.script.ir_module
class TuningBug:
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        for i0, i1, i2, i4, i3, i5, i6 in T.grid(1, 112, 112, 7, 64, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv
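
For reference, the before/after comparison can be reproduced on CPU, where the two outputs match, with a sketch like the following. The llvm target is illustrative, and the Hexagon run, which additionally requires a Hexagon session/RPC setup, is omitted.

import numpy as np
import tvm

# Build the conv2d PrimFunc before and after the reorder (sch is the schedule above;
# tvm.tir.Schedule copies the module, so TuningBug itself is the untouched version).
f_before = tvm.build(TuningBug["conv2d"], target="llvm")
f_after = tvm.build(sch.mod["conv2d"], target="llvm")

dev = tvm.cpu()
a = tvm.nd.array(np.random.uniform(-1, 1, (1, 230, 230, 3)).astype("float16"), dev)
w = tvm.nd.array(np.random.uniform(-1, 1, (7, 7, 3, 64)).astype("float16"), dev)
out_before = tvm.nd.array(np.zeros((1, 112, 112, 64), "float16"), dev)
out_after = tvm.nd.array(np.zeros((1, 112, 112, 64), "float16"), dev)
f_before(a, w, out_before)
f_after(a, w, out_after)

diff = np.abs(out_before.numpy().astype("float32") - out_after.numpy().astype("float32"))
print("max diff:", diff.max(), "mean diff:", diff.mean())  # on Hexagon: 0.5 and 0.0708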

psrivas2 · Dec 13 '22 21:12