tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] Inconsistent inference results after applying ReorderTakeAfterMatmul

Open Cookiee235 opened this issue 1 year ago • 1 comments

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/inconsis222.py", line 258, in <module>
    np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
  File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/root/miniconda3/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 718, in assert_array_compare
    flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 688, in func_assert_same_pos
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

x and y nan location mismatch:
 x: array([[ 7.936000e+04,  8.032000e+04,  8.128000e+04,  8.224000e+04,
         8.320000e+04,  8.416000e+04,  8.512000e+04,  8.608000e+04,
         7.168000e+04,  7.252000e+04,  1.898367e+16,  7.420000e+04,...
 y: array([[ 7.936000e+04,  8.032000e+04,  8.128000e+04,  8.224000e+04,
         8.320000e+04,  8.416000e+04,  8.512000e+04,  8.608000e+04,
                  nan,  7.252000e+04,  7.336000e+04,  7.420000e+04,...

Steps to reproduce

This is a complex test case; I cannot reduce it further because the root cause is unknown.
import tvm
from tvm import relax
import numpy as np
import tvm
metadata = tvm.ir.load_json("""{
  \"root\": 1, 
  \"nodes\": [
    {
      \"type_key\": \"\"
    }, 
    {
      \"type_key\": \"Map\", 
      \"keys\": [
        \"relax.expr.Constant\"
      ], 
      \"data\": [2]
    }, 
    {
      \"type_key\": \"Array\", 
      \"data\": [3]
    }, 
    {
      \"type_key\": \"relax.expr.Constant\", 
      \"attrs\": {
        \"_checked_type_\": \"11\", 
        \"data\": \"0\", 
        \"span\": \"0\", 
        \"struct_info_\": \"4\"
      }
    }, 
    {
      \"type_key\": \"relax.TensorStructInfo\", 
      \"attrs\": {
        \"dtype\": \"float32\", 
        \"ndim\": \"2\", 
        \"shape\": \"5\", 
        \"span\": \"0\", 
        \"vdevice\": \"0\"
      }
    }, 
    {
      \"type_key\": \"relax.expr.ShapeExpr\", 
      \"attrs\": {
        \"_checked_type_\": \"10\", 
        \"span\": \"0\", 
        \"struct_info_\": \"9\", 
        \"values\": \"6\"
      }
    }, 
    {
      \"type_key\": \"Array\", 
      \"data\": [7, 8]
    }, 
    {
      \"type_key\": \"IntImm\", 
      \"attrs\": {
        \"dtype\": \"int64\", 
        \"span\": \"0\", 
        \"value\": \"16\"
      }
    }, 
    {
      \"type_key\": \"IntImm\", 
      \"attrs\": {
        \"dtype\": \"int64\", 
        \"span\": \"0\", 
        \"value\": \"16\"
      }
    }, 
    {
      \"type_key\": \"relax.ShapeStructInfo\", 
      \"attrs\": {
        \"ndim\": \"2\", 
        \"span\": \"0\", 
        \"values\": \"6\"
      }
    }, 
    {
      \"type_key\": \"relax.ShapeType\", 
      \"attrs\": {
        \"ndim\": \"2\", 
        \"span\": \"0\"
      }
    }, 
    {
      \"type_key\": \"relax.DynTensorType\", 
      \"attrs\": {
        \"dtype\": \"float32\", 
        \"ndim\": \"2\", 
        \"span\": \"0\"
      }
    }
  ], 
  \"b64ndarrays\": [
    \"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAACtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\"
  ], 
  \"attrs\": {\"tvm_version\": \"0.17.dev0\"}
}""")
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def cast(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute: T.Buffer((T.int64(16), T.int64(16)), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(gv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("int64", gv[v_i0, v_i1])

    @T.prim_func(private=True)
    def matmul(x: T.Buffer((T.int64(1), T.int64(16)), "float32"), weight: T.Buffer((T.int64(16), T.int64(32)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(32)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(32), T.int64(16)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], weight[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * weight[v_k, v_i1]

    @T.prim_func(private=True)
    def reshape(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_reshape: T.Buffer((T.int64(256),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]

    @T.prim_func(private=True)
    def reshape1(temp: T.Buffer((T.int64(16),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[v_ax1 % T.int64(16)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[v_ax1 % T.int64(16)]

    @T.prim_func(private=True)
    def reshape2(gv: T.Buffer((T.int64(16), T.int64(16)), "int64"), T_reshape: T.Buffer((T.int64(256),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]

    @T.prim_func(private=True)
    def reshape3(temp: T.Buffer((T.int64(32),), "int64"), T_reshape: T.Buffer((T.int64(32),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(32)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(32), ax0)
                T.reads(temp[v_ax0 % T.int64(32)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = temp[v_ax0 % T.int64(32)]

    @T.prim_func(private=True)
    def strided_slice(tensor_1dim: T.Buffer((T.int64(256),), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(16),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(16)):
            with T.block("T_strided_slice_with_axes"):
                v_ax0 = T.axis.spatial(T.int64(16), ax0)
                T.reads(tensor_1dim[v_ax0])
                T.writes(T_strided_slice_with_axes[v_ax0])
                T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]

    @T.prim_func(private=True)
    def strided_slice1(tensor_1dim: T.Buffer((T.int64(256),), "int64"), T_strided_slice_with_axes: T.Buffer((T.int64(32),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(32)):
            with T.block("T_strided_slice_with_axes"):
                v_ax0 = T.axis.spatial(T.int64(32), ax0)
                T.reads(tensor_1dim[v_ax0])
                T.writes(T_strided_slice_with_axes[v_ax0])
                T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]

    @T.prim_func(private=True)
    def take(var_weight_table: T.handle, routing_table: T.Buffer((T.int64(32),), "int64"), T_take: T.Buffer((T.int64(16), T.int64(32)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        weight_table_size = T.int64()
        weight_table = T.match_buffer(var_weight_table, (T.int64(16), weight_table_size))
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(32)):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(weight_table[v_ax0, routing_table[v_ax1]], routing_table[v_ax1])
                T.writes(T_take[v_ax0, v_ax1])
                T_take[v_ax0, v_ax1] = weight_table[v_ax0, routing_table[v_ax1]]

    @R.function
    def main_7(x: R.Tensor((1, 16), dtype="float32"), weight_table: R.Tensor((16, "weight_table_size"), dtype="float32"), routing_table: R.Tensor((32,), dtype="int64")) -> R.Tensor((1, 32), dtype="float32"):
        weight_table_size = T.int64()
        cls = Module
        with R.dataflow():
            weight = R.call_tir(cls.take, (weight_table, routing_table), out_sinfo=R.Tensor((16, 32), dtype="float32"))
            out = R.call_tir(cls.matmul, (x, weight), out_sinfo=R.Tensor((1, 32), dtype="float32"))
            R.output(out)
        return out

    @R.function
    def main() -> R.Tensor((1, 32), dtype="float32"):
        cls = Module
        gv = R.call_tir(cls.add, (metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        tensor_1dim = R.call_tir(cls.reshape, (gv,), out_sinfo=R.Tensor((256,), dtype="float32"))
        temp = R.call_tir(cls.strided_slice, (tensor_1dim,), out_sinfo=R.Tensor((16,), dtype="float32"))
        para0 = R.call_tir(cls.reshape1, (temp,), out_sinfo=R.Tensor((1, 16), dtype="float32"))
        para1: R.Tensor((16, 16), dtype="float32") = gv
        gv_1 = R.call_tir(cls.cast, (gv,), out_sinfo=R.Tensor((16, 16), dtype="int64"))
        tensor_1dim_1 = R.call_tir(cls.reshape2, (gv_1,), out_sinfo=R.Tensor((256,), dtype="int64"))
        temp_1 = R.call_tir(cls.strided_slice1, (tensor_1dim_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
        para2 = R.call_tir(cls.reshape3, (temp_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
        res: R.Tensor((1, 32), dtype="float32") = cls.main_7(para0, para1, para2)
        return res


def compile_mod(mod, func_name, target, *inputs):
    ex = relax.build(mod, target='llvm')
    vm = relax.VirtualMachine(ex, tvm.cpu())
    mod_outputs = vm[f'{func_name}'](*inputs)
    mod_outputs = mod_outputs.numpy()
    return mod_outputs
mod = Module
before_outputs = compile_mod(mod, 'main', 'llvm')
mod = relax.transform.FoldConstant()(mod)
mod = relax.transform.ReorderTakeAfterMatmul()(mod)
after_outputs = compile_mod(mod, 'main', 'llvm')
np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)

CC @Lunderberg @junrushao

Cookiee235 avatar Aug 12 '24 15:08 Cookiee235

The inconsistent results occur even if the FoldConstant and ReorderTakeAfterMatmul lines are removed. This looks like there's some out-of-bounds access in the PrimFuncs. If I replace the TIR functions with their Relax equivalents, I get something as follows, which shows the same inconsistent outputs:

class Module:
    @R.function
    def main():
        metadata_constant = R.reshape(R.arange(256), [16, 16]).astype("float32")
        weight_table: R.Tensor([16, 16], "float32") = metadata_constant + metadata_constant
        x = R.strided_slice(
            R.reshape(weight_table, [256]),
            axes=[0],
            begin=[0],
            end=[16],
        )
        indices = weight_table.astype("int64")
        routing_table = R.strided_slice(
            R.reshape(indices, [256]),
            axes=[0],
            begin=[0],
            end=[32],
        )
        weight = R.take(weight_table, routing_table, axis=1)
        out = R.matmul(x, weight)
        return out

The initial constant from the metadata has values within [0, 256). After adding it to itself and taking the first 32 elements as indices, the index values lie in [0, 64). However, the weight table has shape [16, 16], and these indices are used to access axis 1, whose extent is only 16. This out-of-bounds access results in inconsistent outputs.

Lunderberg avatar Aug 20 '24 18:08 Lunderberg