
[Bug] Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false

Open Cookiee235 opened this issue 1 year ago • 3 comments

Actual behavior

Applying `relax.transform.FuseTIR` to the module in "Steps to reproduce" crashes with the following internal error:

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/reduced/complete/1954_test.py", line 252, in <module>
    mod = relax.transform.FuseTIR()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  23: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  22: tvm::transform::Pass::operator()(tvm::IRModule) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  16: tvm::relax::FuseTIR(tvm::IRModule)
  15: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
  14: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
  13: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
  10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  4: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  3: tvm::relax::FusedTIRConstructor::VisitBinding_(tvm::relax::VarBindingNode const*)
  2: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  1: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  0: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::CallNode const*)
  File "/software/tvm/src/relax/transform/fuse_tir.cc", line 527
InternalError: Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false: Only schedulable functions (whose body is the root block) can be fused
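FuseTIR only fuses "schedulable" PrimFuncs, i.e. functions whose body is a single root BlockRealize. Several PrimFuncs in the module below (`add`, `exp`, `exp_8`, `log`, `pad`, `relu_10`, `reshape_6`) have a bare `T.evaluate(0)` body instead, so the check at fuse_tir.cc:527 fires. A minimal contrast, sketched as standalone TVMScript functions (the names `schedulable`/`not_schedulable` are illustrative, not from the repro):

from tvm.script import tir as T

@T.prim_func
def schedulable(A: T.Buffer((8,), "float32")):
    # The TVMScript parser wraps a block-containing body in an implicit
    # root block, so prim_func.body is a BlockRealize: FuseTIR accepts this.
    for i in range(8):
        with T.block("fill"):
            vi = T.axis.spatial(8, i)
            A[vi] = T.float32(0)

@T.prim_func
def not_schedulable(A: T.Buffer((8,), "float32")):
    # Bare Evaluate body with no blocks: prim_func.body is not a
    # BlockRealize, so FuseTIR's check rejects this function.
    T.evaluate(0)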

Steps to reproduce

Run the script below:

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
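        # NOTE: bare Evaluate body with no root block; this and the other
        # T.evaluate(0) stubs below are what trip FuseTIR's check.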
        T.evaluate(0)

    @T.prim_func(private=True)
    def concatenate(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(3211254),), "float32"), T_concat: T.Buffer((T.int64(3211264),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211264)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(3211264), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate1(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(12534),), "float32"), T_concat: T.Buffer((T.int64(12544),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12544)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(12544), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate2(tensor_1dim: T.Buffer((T.int64(6),), "float32"), pad_tensor: T.Buffer((T.int64(2),), "float32"), T_concat: T.Buffer((T.int64(8),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(8)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(8), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(6) <= v_ax0, pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
        T.evaluate(0)

    @T.prim_func
    def exp_8(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def layer_norm_2(x: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32"), gamma: T.Buffer((T.int64(112), T.int64(112)), "float32"), beta: T.Buffer((T.int64(112), T.int64(112)), "float32"), T_layer_norm: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        x_red_temp_v0 = T.alloc_buffer((T.int64(4), T.int64(64)))
        x_red_temp_v1 = T.alloc_buffer((T.int64(4), T.int64(64)))
        for ax0, ax1, k2, k3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("x_red_temp"):
                v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1])
                with T.init():
                    x_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                    x_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                v_x_red_temp_v0: T.float32 = x_red_temp_v0[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3]
                v_x_red_temp_v1: T.float32 = x_red_temp_v1[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3] * x[v_ax0, v_ax1, v_k2, v_k3]
                x_red_temp_v0[v_ax0, v_ax1] = v_x_red_temp_v0
                x_red_temp_v1[v_ax0, v_ax1] = v_x_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, v_ax1, v_ax2, v_ax3] - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) * T.rsqrt(x_red_temp_v1[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) * (x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]

    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def relu_10(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def reshape_6(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def reshape_62(gv: T.Buffer((T.int64(10),), "float32"), T_reshape: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(10)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(10), ax0)
                T.reads(gv[v_ax0 % T.int64(10)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(10)]

    @T.prim_func(private=True)
    def reshape_63(temp: T.Buffer((T.int64(3211264),), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)]

    @T.prim_func(private=True)
    def reshape_64(temp: T.Buffer((T.int64(12544),), "float32"), T_reshape: T.Buffer((T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)]

    @T.prim_func(private=True)
    def reshape_66(temp: T.Buffer((T.int64(8),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)]

    @T.prim_func(private=True)
    def reshape_67(b: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_reshape: T.Buffer((T.int64(6),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(6)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(6), ax0)
                T.reads(b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)]

    @T.prim_func
    def tir_matmul(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block(""):
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                T.reads(A[i, k], B[j, k])
                T.writes(C[i, j])
                with T.init():
                    C[i, j] = T.float32(0)
                C[i, j] = C[i, j] + A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i, j in T.grid(32, 32):
            with T.block(""):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = T.max(A[vi, vj], T.float32(0))

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(3211254),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211254)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(3211254), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(12534),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12534)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(12534), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros2(T_full: T.Buffer((T.int64(2),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @R.function
    def main_0(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            tensor_1dim = R.call_tir(cls.reshape_67, (b,), out_sinfo=R.Tensor((6,), dtype="float32"))
            pad_tensor = R.call_tir(cls.zeros2, R.tuple(), out_sinfo=R.Tensor((2,), dtype="float32"))
            temp = R.call_tir(cls.concatenate2, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((8,), dtype="float32"))
            para0 = R.call_tir(cls.reshape_66, (temp,), out_sinfo=R.Tensor((2, 4), dtype="float32"))
            res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0(para0)
            R.output(res)
        return res

    @R.function
    def main_0_9_0(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        R.func_attr({"relax.force_pure": 1})
        cls = Module
        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global"))
        gv: R.Tensor((10,), dtype="float32") = alloc4
        tensor_1dim = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((3211254,), dtype="float32"))
        temp = R.call_tir(cls.concatenate, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((3211264,), dtype="float32"))
        para0 = R.call_tir(cls.reshape_63, (temp,), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
        tensor_1dim_1 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_1 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_1 = R.call_tir(cls.concatenate1, (tensor_1dim_1, pad_tensor_1), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para1 = R.call_tir(cls.reshape_64, (temp_1,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        tensor_1dim_2 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_2 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_2 = R.call_tir(cls.concatenate1, (tensor_1dim_2, pad_tensor_2), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para2 = R.call_tir(cls.reshape_64, (temp_2,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0_2(para0, para1, para2)
        return res

    @R.function
    def main_0_9_0_2(x: R.Tensor((4, 64, 112, 112), dtype="float32"), gamma: R.Tensor((112, 112), dtype="float32"), beta: R.Tensor((112, 112), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            ln = R.call_tir(cls.layer_norm_2, (x, gamma, beta), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
            R.output(ln)
        return ln

mod = Module
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)  # raises the InternalError shown above
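
For diagnosing modules like this one, a small helper can list the PrimFuncs whose body is not a root BlockRealize before running the pass. This is a sketch against the public tvm.tir API; find_unschedulable is a hypothetical name, not a TVM utility:

from tvm import tir

def find_unschedulable(mod):
    """Return names of PrimFuncs whose body is not a root BlockRealize.

    These are the functions that FuseTIR's check at fuse_tir.cc:527
    rejects (hypothetical helper, not part of TVM).
    """
    return [
        gvar.name_hint
        for gvar, func in mod.functions.items()
        if isinstance(func, tir.PrimFunc)
        and not isinstance(func.body, tir.BlockRealize)
    ]

print(find_unschedulable(Module))
# e.g. ['add', 'exp', 'exp_8', 'log', 'pad', 'relu_10', 'reshape_6']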

Cookiee235 · Sep 05 '24 14:09