tvm
[Bug] InternalError: non-normalized expression R.memory.kill_tensor(metadata["relax.expr.Constant"][0])
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/reduced/complete/328_test.py", line 162, in <module>
ex = relax.build(mod, target='llvm')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
mod = seq(mod)
^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
25: tvm::transform::Pass::operator()(tvm::IRModule) const
24: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
23: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
22: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
21: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
20: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
19: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::KillAfterLastUse()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::KillAfterLastUse()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
18: tvm::relax::KillAfterLastUse(tvm::RelayExpr)
17: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
16: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
15: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
14: tvm::relax::KillInserter::VisitExpr_(tvm::relax::FunctionNode const*)
13: tvm::relax::CollectLastUsage::Collect(tvm::RelayExpr const&)
12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
11: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
4: tvm::relax::CollectLastUsage::VisitBinding(tvm::relax::Binding const&)
3: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
2: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
1: _ZZN3tvm5relax11ExprVisitor22InitVisitBindingVTabl
0: tvm::relax::CollectLastUsage::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
File "/software/tvm/src/relax/transform/kill_after_last_use.cc", line 171
InternalError: Check failed: (killed_object) is false: Internal error: non-normalized expression R.memory.kill_tensor(metadata["relax.expr.Constant"][0])
Steps to reproduce
Reproducible script:
import tvm
from tvm import relax
# Serialized TVM metadata map.  Its only key, "relax.expr.Constant", holds an
# array with a single relax Constant (node 3).  That Constant's struct info
# (node 4) is a float32 tensor with ndim 2 whose shape (nodes 5-8) is
# (16, 16).  The tensor payload is the single base64 entry in "b64ndarrays";
# it appears to encode the values 0.0, 1.0, ..., 255.0 (NOTE(review): decoded
# by eye from the base64, confirm if it matters).  `Module.main` references
# this object by name as metadata["relax.expr.Constant"][0], so the JSON text
# must be kept exactly as-is.
metadata = tvm.ir.load_json("""{
\"root\": 1,
\"nodes\": [
{
\"type_key\": \"\"
},
{
\"type_key\": \"Map\",
\"keys\": [
\"relax.expr.Constant\"
],
\"data\": [2]
},
{
\"type_key\": \"Array\",
\"data\": [3]
},
{
\"type_key\": \"relax.expr.Constant\",
\"attrs\": {
\"_checked_type_\": \"11\",
\"data\": \"0\",
\"span\": \"0\",
\"struct_info_\": \"4\"
}
},
{
\"type_key\": \"relax.TensorStructInfo\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"shape\": \"5\",
\"span\": \"0\",
\"vdevice\": \"0\"
}
},
{
\"type_key\": \"relax.expr.ShapeExpr\",
\"attrs\": {
\"_checked_type_\": \"10\",
\"span\": \"0\",
\"struct_info_\": \"9\",
\"values\": \"6\"
}
},
{
\"type_key\": \"Array\",
\"data\": [7, 8]
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"relax.ShapeStructInfo\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\",
\"values\": \"6\"
}
},
{
\"type_key\": \"relax.ShapeType\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\"
}
},
{
\"type_key\": \"relax.DynTensorType\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"span\": \"0\"
}
}
],
\"b64ndarrays\": [
\"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAACtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\"
],
\"attrs\": {\"tvm_version\": \"0.17.dev0\"}
}""")
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    # Fuzzer-reduced IRModule used to reproduce the KillAfterLastUse failure.
    # `main` calls only `add_2` and `cast1`.  The other prim_funcs have empty
    # bodies (`T.evaluate(0)`) and are not called from `main`; they are
    # leftovers from the reduction.

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    # Elementwise T_add = A + B over a 16x16 float32 grid.
    @T.prim_func(private=True)
    def add_2(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    # Casts each element of a 16x16 float32 buffer to float16.
    @T.prim_func(private=True)
    def cast1(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute: T.Buffer((T.int64(16), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(gv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", gv[v_i0, v_i1])

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def relu(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    # Placeholder: empty body, not called from `main`.
    @T.prim_func
    def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    # Entry point: adds the metadata constant to itself (float32), then casts
    # the sum to float16 and returns it.
    @R.function
    def main() -> R.Tensor((16, 16), dtype="float16"):
        cls = Module
        with R.dataflow():
            # Both operands are the same metadata constant.  Because every
            # input is constant, FoldConstant can presumably evaluate this call
            # at compile time (NOTE(review): inferred, confirm against the pass).
            gv = R.call_tir(cls.add_2, (metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16), dtype="float32"))
            # `gv` is used only here, so this is its last use.  Presumably
            # KillAfterLastUse attaches a kill_tensor to it at this point.
            gv_1 = R.call_tir(cls.cast1, (gv,), out_sinfo=R.Tensor((16, 16), dtype="float16"))
            R.output(gv_1)
        return gv_1
# Reproduction driver.  Apply KillAfterLastUse followed by FoldConstant by
# hand; the reporter notes that only this exact pass sequence triggers the
# failure.  The default relax.build pipeline is then run on the result.  In
# the report it raises InternalError "non-normalized expression
# R.memory.kill_tensor(...)" from inside KillAfterLastUse.
seq = tvm.transform.Sequential(
    [
        relax.transform.KillAfterLastUse(),
        relax.transform.FoldConstant(),
    ]
)
mod = seq(Module)
ex = relax.build(mod, target="llvm")
CC @Lunderberg