[MetaSchedule][Hexagon] conv2d produces different results after tuning
The following PrimFunc produces different results after tuning on Hexagon.
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
# function attr dict
T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
# body
# with T.block("root")
for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
with T.block("conv2d_nhwc"):
nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
T.writes(conv2d_nhwc[nn, yy, xx, ff])
with T.init():
conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
After tuning, the PrimFunc is transformed to:
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
# body
# with T.block("root")
conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float16")
for i0_0_i1_0_i2_0_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":T.int64(512), "pragma_unroll_explicit":T.int64(1)}):
for i3_0 in T.serial(1):
for i0_1_init, i1_1_init, i2_1_init, i3_1_init, i0_2_init, i1_2_init, i2_2_init in T.grid(1, 2, 16, 1, 1, 2, 1):
for i3_2_fused_init in T.vectorized(64):
with T.block("conv2d_nhwc_init"):
nn = T.axis.spatial(1, i0_1_init + i0_2_init)
yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1_init * 2 + i1_2_init)
xx = T.axis.spatial(112, i2_2_init + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1_init)
ff = T.axis.spatial(64, i3_0 * 64 + i3_1_init * 64 + i3_2_fused_init)
T.reads()
T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
conv2d_nhwc_global[nn, yy, xx, ff] = T.float16(0)
for i4_0, i5_0, i6_0 in T.grid(1, 7, 1):
for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2 in T.grid(1, 2, 16, 1, 7, 1, 3, 1, 2, 1):
for i3_2_fused in T.vectorized(64):
with T.block("conv2d_nhwc_update"):
nn = T.axis.spatial(1, i0_1 + i0_2)
yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1 * 2 + i1_2)
xx = T.axis.spatial(112, i2_2 + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1)
ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 64 + i3_2_fused)
ry = T.axis.reduce(7, i4_0 * 7 + i4_1)
rx = T.axis.reduce(7, i5_0 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(conv2d_nhwc_global[nn, yy, xx, ff], lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
for ax0, ax1, ax2 in T.grid(1, 4, 16):
for ax3_fused in T.vectorized(64):
with T.block("conv2d_nhwc_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + ax1)
v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused % 7 * 16 + ax2)
v3 = T.axis.spatial(64, ax3_fused)
T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_nhwc[v0, v1, v2, v3])
conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
The two PrimFuncs produce different results on Hexagon hardware. This needs to be investigated.
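One way to tell which variant is wrong is to compare both against a higher-precision reference. Below is a rough NumPy sketch of such a reference for this stride-2 NHWC convolution (a hypothetical helper, not part of the report above):

```python
import numpy as np

def conv2d_nhwc_ref(data, weight, stride=2):
    # float32 reference for the fp16 conv2d above (illustrative only)
    n, h, w, c = data.shape
    kh, kw, _, f = weight.shape
    oh, ow = (h - kh) // stride + 1, (w - kw) // stride + 1
    d32, w32 = data.astype("float32"), weight.astype("float32")
    out = np.zeros((n, oh, ow, f), dtype="float32")
    for y in range(oh):
        for x in range(ow):
            patch = d32[:, y * stride : y * stride + kh, x * stride : x * stride + kw, :]
            out[:, y, x, :] = np.einsum("nhwc,hwcf->nf", patch, w32)
    return out
```

Running the untuned and the tuned kernel on the same random fp16 inputs and comparing each against `conv2d_nhwc_ref` (with a loose fp16 tolerance in `np.testing.assert_allclose`) should flag which one diverges.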
Thanks @psrivas2 for reporting the issue!
Two questions that could help us know more about the context:
- Is it Hexagon specific, i.e. if we tune conv2d on CPU and GPU, will these incorrect results also happen?
- Is it only conv2d, i.e. if we tune other kernels on Hexagon, will the kernels before and after tuning give different results?
First, it is Hexagon specific: on CPU the tuned kernel output is the same as the untuned output. Second, I have only observed this behavior for this specific kernel. For example, after fusion, ResNet has around 31 PrimFuncs; of those, only the one PrimFunc that had the above block as one of its fused operations produced results different from the untuned PrimFunc.
~In addition to that, this is definitely some incorrect transformation of the untuned PrimFunc, as the two PrimFuncs shown above give different results even on CPU.~
I think I have narrowed it down to the reordering of loops.
On Hexagon, the following two modules, which differ only in the order of loops i3 and i4, produce different numeric results. The max difference in values is 0.5 and the mean difference is 0.0708. This only happens for the fp16 dtype.
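For context, fp16 accumulation is not associative, so reordering a reduction can legitimately change rounding; whether that accounts for differences of this magnitude on Hexagon is a separate question. A minimal NumPy sketch (made-up values, not tied to this kernel) showing that accumulating the same products in a different order can give a different fp16 result:

```python
import numpy as np

# Toy illustration: one output element of the kernel accumulates 7*7*3 = 147 products.
rng = np.random.default_rng(0)
a = rng.uniform(-1, 1, size=147).astype("float16")
b = rng.uniform(-1, 1, size=147).astype("float16")

acc_fwd = np.float16(0)
for x, y in zip(a, b):  # accumulate in the original order
    acc_fwd = np.float16(acc_fwd + x * y)

acc_rev = np.float16(0)
for x, y in zip(a[::-1], b[::-1]):  # same products, reversed accumulation order
    acc_rev = np.float16(acc_rev + x * y)

print(acc_fwd, acc_rev, abs(float(acc_fwd) - float(acc_rev)))
```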
@tvm.script.ir_module
class TuningBug:
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
# body
# with T.block("root")
for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
with T.block("conv2d_nhwc"):
nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
T.writes(conv2d_nhwc[nn, yy, xx, ff])
with T.init():
conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])
@R.function
def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
with R.dataflow():
gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
R.output(gv)
return gv
Reordering loops i3 and i4:
import tvm

mod = TuningBug  # the IRModule defined above
sch = tvm.tir.Schedule(mod)
b0 = sch.get_block("conv2d_nhwc", func_name="conv2d")
i0, i1, i2, i3, i4, i5, i6 = sch.get_loops(b0)
sch.reorder(i4, i3)  # move the ry loop (i4) above the ff loop (i3)
The modified module (`sch.mod`) then looks like this:
@tvm.script.ir_module
class TuningBug:
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
# body
# with T.block("root")
for i0, i1, i2, i4, i3, i5, i6 in T.grid(1, 112, 112, 7, 64, 7, 3):
with T.block("conv2d_nhwc"):
nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
T.writes(conv2d_nhwc[nn, yy, xx, ff])
with T.init():
conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])
@R.function
def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
with R.dataflow():
gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
R.output(gv)
return gv
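For reference, a rough sketch of how the two variants could be compared numerically. It builds just the PrimFunc with the LLVM backend for illustration (the divergence above shows up on Hexagon, not on CPU), and `mod_before` / `mod_after` are placeholders for the module before and after the `sch.reorder` call:

```python
import numpy as np
import tvm

def run_conv2d(prim_func, data_np, weight_np):
    # Build the standalone PrimFunc and run it on CPU (illustrative only).
    lib = tvm.build(prim_func, target="llvm")
    dev = tvm.cpu()
    data = tvm.nd.array(data_np, dev)
    weight = tvm.nd.array(weight_np, dev)
    out = tvm.nd.empty((1, 112, 112, 64), "float16", dev)
    lib["conv2d"](data, weight, out)  # look up the function by its global_symbol
    return out.numpy()

data_np = np.random.uniform(-1, 1, size=(1, 230, 230, 3)).astype("float16")
weight_np = np.random.uniform(-1, 1, size=(7, 7, 3, 64)).astype("float16")

out_before = run_conv2d(mod_before["conv2d"], data_np, weight_np)
out_after = run_conv2d(mod_after["conv2d"], data_np, weight_np)
diff = np.abs(out_before.astype("float32") - out_after.astype("float32"))
print("max diff:", diff.max(), "mean diff:", diff.mean())
```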