[Bug] Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/reduced/complete/1954_test.py", line 252, in <module>
mod = relax.transform.FuseTIR()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
23: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
22: tvm::transform::Pass::operator()(tvm::IRModule) const
21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
20: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
18: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
17: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
16: tvm::relax::FuseTIR(tvm::IRModule)
15: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
14: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
13: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
12: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
4: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
3: tvm::relax::FusedTIRConstructor::VisitBinding_(tvm::relax::VarBindingNode const*)
2: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
1: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
0: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::CallNode const*)
File "/software/tvm/src/relax/transform/fuse_tir.cc", line 527
InternalError: Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false: Only schedulable functions (whose body is the root block) can be fused
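The ICHECK at src/relax/transform/fuse_tir.cc:527 fires because FuseTIR assumes every PrimFunc selected for fusion is schedulable, i.e. its body is a root tir::BlockRealizeNode. Several PrimFuncs in the module below (add, exp, exp_8, log, pad, relu_10, reshape_6) have an opaque T.evaluate(0) body, so once FuseOps groups them into a fused function, FuseTIR aborts with an internal error rather than a user-facing diagnostic. A minimal sketch of the distinction follows; the is_schedulable helper is hypothetical, written only to mirror the failing check in Python:

import tvm
from tvm.script import tir as T

@T.prim_func
def schedulable(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    # The parsed body is an implicit root block (a BlockRealize node),
    # which is what FuseTIR needs in order to fuse the function.
    for i in range(8):
        with T.block("copy"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi]

@T.prim_func
def opaque(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    # The parsed body is a bare Evaluate node, not a BlockRealize --
    # exactly the shape of function that trips the ICHECK above.
    T.evaluate(0)

# Hypothetical helper mirroring the C++ check in fuse_tir.cc:
def is_schedulable(func: tvm.tir.PrimFunc) -> bool:
    return isinstance(func.body, tvm.tir.BlockRealize)

print(is_schedulable(schedulable))  # True
print(is_schedulable(opaque))       # False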
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def concatenate(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(3211254),), "float32"), T_concat: T.Buffer((T.int64(3211264),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211264)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(3211264), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate1(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(12534),), "float32"), T_concat: T.Buffer((T.int64(12544),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12544)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(12544), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate2(tensor_1dim: T.Buffer((T.int64(6),), "float32"), pad_tensor: T.Buffer((T.int64(2),), "float32"), T_concat: T.Buffer((T.int64(8),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(8)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(8), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(6) <= v_ax0, pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
        T.evaluate(0)

    @T.prim_func
    def exp_8(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def layer_norm_2(x: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32"), gamma: T.Buffer((T.int64(112), T.int64(112)), "float32"), beta: T.Buffer((T.int64(112), T.int64(112)), "float32"), T_layer_norm: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        x_red_temp_v0 = T.alloc_buffer((T.int64(4), T.int64(64)))
        x_red_temp_v1 = T.alloc_buffer((T.int64(4), T.int64(64)))
        for ax0, ax1, k2, k3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("x_red_temp"):
                v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1])
                with T.init():
                    x_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                    x_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                v_x_red_temp_v0: T.float32 = x_red_temp_v0[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3]
                v_x_red_temp_v1: T.float32 = x_red_temp_v1[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3] * x[v_ax0, v_ax1, v_k2, v_k3]
                x_red_temp_v0[v_ax0, v_ax1] = v_x_red_temp_v0
                x_red_temp_v1[v_ax0, v_ax1] = v_x_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, v_ax1, v_ax2, v_ax3] - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) * T.rsqrt(x_red_temp_v1[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) * (x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]

    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def relu_10(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def reshape_6(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def reshape_62(gv: T.Buffer((T.int64(10),), "float32"), T_reshape: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(10)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(10), ax0)
                T.reads(gv[v_ax0 % T.int64(10)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(10)]

    @T.prim_func(private=True)
    def reshape_63(temp: T.Buffer((T.int64(3211264),), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)]

    @T.prim_func(private=True)
    def reshape_64(temp: T.Buffer((T.int64(12544),), "float32"), T_reshape: T.Buffer((T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)]

    @T.prim_func(private=True)
    def reshape_66(temp: T.Buffer((T.int64(8),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)]

    @T.prim_func(private=True)
    def reshape_67(b: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_reshape: T.Buffer((T.int64(6),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(6)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(6), ax0)
                T.reads(b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)]

    @T.prim_func
    def tir_matmul(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block(""):
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                T.reads(A[i, k], B[j, k])
                T.writes(C[i, j])
                with T.init():
                    C[i, j] = T.float32(0)
                C[i, j] = C[i, j] + A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i, j in T.grid(32, 32):
            with T.block(""):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = T.max(A[vi, vj], T.float32(0))

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(3211254),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211254)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(3211254), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(12534),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12534)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(12534), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros2(T_full: T.Buffer((T.int64(2),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @R.function
    def main_0(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            tensor_1dim = R.call_tir(cls.reshape_67, (b,), out_sinfo=R.Tensor((6,), dtype="float32"))
            pad_tensor = R.call_tir(cls.zeros2, R.tuple(), out_sinfo=R.Tensor((2,), dtype="float32"))
            temp = R.call_tir(cls.concatenate2, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((8,), dtype="float32"))
            para0 = R.call_tir(cls.reshape_66, (temp,), out_sinfo=R.Tensor((2, 4), dtype="float32"))
            res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0(para0)
            R.output(res)
        return res

    @R.function
    def main_0_9_0(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        R.func_attr({"relax.force_pure": 1})
        cls = Module
        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global"))
        gv: R.Tensor((10,), dtype="float32") = alloc4
        tensor_1dim = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((3211254,), dtype="float32"))
        temp = R.call_tir(cls.concatenate, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((3211264,), dtype="float32"))
        para0 = R.call_tir(cls.reshape_63, (temp,), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
        tensor_1dim_1 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_1 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_1 = R.call_tir(cls.concatenate1, (tensor_1dim_1, pad_tensor_1), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para1 = R.call_tir(cls.reshape_64, (temp_1,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        tensor_1dim_2 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_2 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_2 = R.call_tir(cls.concatenate1, (tensor_1dim_2, pad_tensor_2), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para2 = R.call_tir(cls.reshape_64, (temp_2,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0_2(para0, para1, para2)
        return res

    @R.function
    def main_0_9_0_2(x: R.Tensor((4, 64, 112, 112), dtype="float32"), gamma: R.Tensor((112, 112), dtype="float32"), beta: R.Tensor((112, 112), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            ln = R.call_tir(cls.layer_norm_2, (x, gamma, beta), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
            R.output(ln)
        return ln

mod = Module
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)
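
The crash only surfaces in the last pass: AnnotateTIROpPattern assigns op pattern kinds, FuseOps groups the annotated calls, and FuseTIR then requires each grouped PrimFunc body to be a root BlockRealize. As a quick diagnostic sketch (assuming only the public TVM Python API; report_unschedulable is a hypothetical helper, not part of TVM), the functions that would trip the check can be listed before FuseTIR runs:

import tvm

def report_unschedulable(mod: tvm.IRModule) -> None:
    # Print every PrimFunc whose body is not a root BlockRealize,
    # i.e. the functions FuseTIR would reject with this ICHECK.
    for gvar, func in mod.functions.items():
        if isinstance(func, tvm.tir.PrimFunc) and not isinstance(func.body, tvm.tir.BlockRealize):
            print(f"{gvar.name_hint}: body is {type(func.body).__name__}, not BlockRealize")

report_unschedulable(Module)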