[Bug] Inconsistent inference results after applying ReorderTakeAfterMatmul
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/inconsis222.py", line 258, in <module>
np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
File "/root/miniconda3/lib/python3.12/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 718, in assert_array_compare
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 688, in func_assert_same_pos
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.001, atol=0.001
x and y nan location mismatch:
x: array([[ 7.936000e+04, 8.032000e+04, 8.128000e+04, 8.224000e+04,
8.320000e+04, 8.416000e+04, 8.512000e+04, 8.608000e+04,
7.168000e+04, 7.252000e+04, 1.898367e+16, 7.420000e+04,...
y: array([[ 7.936000e+04, 8.032000e+04, 8.128000e+04, 8.224000e+04,
8.320000e+04, 8.416000e+04, 8.512000e+04, 8.608000e+04,
nan, 7.252000e+04, 7.336000e+04, 7.420000e+04,...
Steps to reproduce
This is a complex test case; I cannot reduce it further because the root cause is unknown.
import tvm
from tvm import relax
import numpy as np
metadata = tvm.ir.load_json("""{
\"root\": 1,
\"nodes\": [
{
\"type_key\": \"\"
},
{
\"type_key\": \"Map\",
\"keys\": [
\"relax.expr.Constant\"
],
\"data\": [2]
},
{
\"type_key\": \"Array\",
\"data\": [3]
},
{
\"type_key\": \"relax.expr.Constant\",
\"attrs\": {
\"_checked_type_\": \"11\",
\"data\": \"0\",
\"span\": \"0\",
\"struct_info_\": \"4\"
}
},
{
\"type_key\": \"relax.TensorStructInfo\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"shape\": \"5\",
\"span\": \"0\",
\"vdevice\": \"0\"
}
},
{
\"type_key\": \"relax.expr.ShapeExpr\",
\"attrs\": {
\"_checked_type_\": \"10\",
\"span\": \"0\",
\"struct_info_\": \"9\",
\"values\": \"6\"
}
},
{
\"type_key\": \"Array\",
\"data\": [7, 8]
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"relax.ShapeStructInfo\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\",
\"values\": \"6\"
}
},
{
\"type_key\": \"relax.ShapeType\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\"
}
},
{
\"type_key\": \"relax.DynTensorType\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"span\": \"0\"
}
}
],
\"b64ndarrays\": [
\"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAACtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\"
],
\"attrs\": {\"tvm_version\": \"0.17.dev0\"}
}""")
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
@T.prim_func(private=True)
def cast(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute: T.Buffer((T.int64(16), T.int64(16)), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(16), T.int64(16)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(gv[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.Cast("int64", gv[v_i0, v_i1])
@T.prim_func(private=True)
def matmul(x: T.Buffer((T.int64(1), T.int64(16)), "float32"), weight: T.Buffer((T.int64(16), T.int64(32)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(32)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(32), T.int64(16)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], weight[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * weight[v_k, v_i1]
@T.prim_func(private=True)
def reshape(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_reshape: T.Buffer((T.int64(256),), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(256)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(256), ax0)
T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]
@T.prim_func(private=True)
def reshape1(temp: T.Buffer((T.int64(16),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(16)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
with T.block("T_reshape"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(temp[v_ax1 % T.int64(16)])
T.writes(T_reshape[v_ax0, v_ax1])
T_reshape[v_ax0, v_ax1] = temp[v_ax1 % T.int64(16)]
@T.prim_func(private=True)
def reshape2(gv: T.Buffer((T.int64(16), T.int64(16)), "int64"), T_reshape: T.Buffer((T.int64(256),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(256)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(256), ax0)
T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]
@T.prim_func(private=True)
def reshape3(temp: T.Buffer((T.int64(32),), "int64"), T_reshape: T.Buffer((T.int64(32),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(32)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(32), ax0)
T.reads(temp[v_ax0 % T.int64(32)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = temp[v_ax0 % T.int64(32)]
@T.prim_func(private=True)
def strided_slice(tensor_1dim: T.Buffer((T.int64(256),), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(16),), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(16)):
with T.block("T_strided_slice_with_axes"):
v_ax0 = T.axis.spatial(T.int64(16), ax0)
T.reads(tensor_1dim[v_ax0])
T.writes(T_strided_slice_with_axes[v_ax0])
T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
@T.prim_func(private=True)
def strided_slice1(tensor_1dim: T.Buffer((T.int64(256),), "int64"), T_strided_slice_with_axes: T.Buffer((T.int64(32),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(32)):
with T.block("T_strided_slice_with_axes"):
v_ax0 = T.axis.spatial(T.int64(32), ax0)
T.reads(tensor_1dim[v_ax0])
T.writes(T_strided_slice_with_axes[v_ax0])
T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
@T.prim_func(private=True)
def take(var_weight_table: T.handle, routing_table: T.Buffer((T.int64(32),), "int64"), T_take: T.Buffer((T.int64(16), T.int64(32)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
weight_table_size = T.int64()
weight_table = T.match_buffer(var_weight_table, (T.int64(16), weight_table_size))
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(16), T.int64(32)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(weight_table[v_ax0, routing_table[v_ax1]], routing_table[v_ax1])
T.writes(T_take[v_ax0, v_ax1])
T_take[v_ax0, v_ax1] = weight_table[v_ax0, routing_table[v_ax1]]
@R.function
def main_7(x: R.Tensor((1, 16), dtype="float32"), weight_table: R.Tensor((16, "weight_table_size"), dtype="float32"), routing_table: R.Tensor((32,), dtype="int64")) -> R.Tensor((1, 32), dtype="float32"):
weight_table_size = T.int64()
cls = Module
with R.dataflow():
weight = R.call_tir(cls.take, (weight_table, routing_table), out_sinfo=R.Tensor((16, 32), dtype="float32"))
out = R.call_tir(cls.matmul, (x, weight), out_sinfo=R.Tensor((1, 32), dtype="float32"))
R.output(out)
return out
@R.function
def main() -> R.Tensor((1, 32), dtype="float32"):
cls = Module
gv = R.call_tir(cls.add, (metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16), dtype="float32"))
tensor_1dim = R.call_tir(cls.reshape, (gv,), out_sinfo=R.Tensor((256,), dtype="float32"))
temp = R.call_tir(cls.strided_slice, (tensor_1dim,), out_sinfo=R.Tensor((16,), dtype="float32"))
para0 = R.call_tir(cls.reshape1, (temp,), out_sinfo=R.Tensor((1, 16), dtype="float32"))
para1: R.Tensor((16, 16), dtype="float32") = gv
gv_1 = R.call_tir(cls.cast, (gv,), out_sinfo=R.Tensor((16, 16), dtype="int64"))
tensor_1dim_1 = R.call_tir(cls.reshape2, (gv_1,), out_sinfo=R.Tensor((256,), dtype="int64"))
temp_1 = R.call_tir(cls.strided_slice1, (tensor_1dim_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
para2 = R.call_tir(cls.reshape3, (temp_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
res: R.Tensor((1, 32), dtype="float32") = cls.main_7(para0, para1, para2)
return res
def compile_mod(mod, func_name, target, *inputs):
    # Build for the requested target and run the named function on the CPU.
    ex = relax.build(mod, target=target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    mod_outputs = vm[func_name](*inputs)
    return mod_outputs.numpy()
mod = Module
before_outputs = compile_mod(mod, 'main', 'llvm')
mod = relax.transform.FoldConstant()(mod)
mod = relax.transform.ReorderTakeAfterMatmul()(mod)
after_outputs = compile_mod(mod, 'main', 'llvm')
np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
CC @Lunderberg @junrushao
The inconsistent results occur even if the FoldConstant and ReorderTakeAfterMatmul lines are removed, so this looks like an out-of-bounds access in the PrimFuncs rather than a problem with the pass itself. Replacing the TIR functions with their Relax equivalents gives the following module, which shows the same inconsistent outputs:
@I.ir_module
class Module:
@R.function
def main():
metadata_constant = R.reshape(R.arange(256), [16, 16]).astype("float32")
weight_table: R.Tensor([16, 16], "float32") = metadata_constant + metadata_constant
x = R.strided_slice(
R.reshape(weight_table, [256]),
axes=[0],
begin=[0],
end=[16],
)
indices = weight_table.astype("int64")
routing_table = R.strided_slice(
R.reshape(indices, [256]),
axes=[0],
begin=[0],
end=[32],
)
weight = R.take(weight_table, routing_table, axis=1)
out = R.matmul(x, weight)
return out
The initial constant from the metadata has values in [0, 256). After adding it to itself, casting to int64, and taking the first 32 elements, the resulting index values lie in [0, 64). However, the weight table has shape [16, 16], and these indices are used to access axis 1, whose extent is only 16, so indices as large as 62 read past the end of the buffer. This out-of-bounds access is what produces the inconsistent outputs.
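For concreteness, here is a minimal NumPy sketch (my addition, mirroring the Relax snippet above, not part of the original repro) that confirms the sliced indices exceed the axis-1 extent:

import numpy as np

# Mirror of the Relax snippet: the constant is arange(256) reshaped to [16, 16].
metadata_constant = np.arange(256, dtype="float32").reshape(16, 16)
weight_table = metadata_constant + metadata_constant  # values 0, 2, ..., 510

# routing_table: first 32 elements of the flattened table, cast to int64.
routing_table = weight_table.astype("int64").reshape(256)[:32]

print(routing_table.min(), routing_table.max())  # 0 62
print(weight_table.shape[1])                     # 16: the extent of axis 1

# NumPy rejects these indices outright; the TVM-generated take does not bounds-check.
try:
    np.take(weight_table, routing_table, axis=1)
except IndexError as err:
    print("out of bounds:", err)

Since np.take raises an IndexError here while the compiled take silently reads past the buffer, the output depends on whatever memory happens to follow it, which is why the before/after results (and the nan locations in the traceback above) can legitimately differ.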