[Bug] [Relax] VMBuiltinLower expects bound value to be a ShapeExpr
Actual behavior
Traceback (most recent call last):
  File "test_simple.py", line 46, in <module>
    ex = relax.build(mod, target='llvm') # crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 265, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 265, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  27: tvm::transform::Pass::operator()(tvm::IRModule) const
  26: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  21: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::VMBuiltinLower()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::VMBuiltinLower()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  20: tvm::relax::VMBuiltinLower(tvm::RelayExpr const&)
  19: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  18: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  16: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  15: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  11: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  10: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  9: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  8: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  7: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  6: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
  5: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  4: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  3: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  1: tvm::relax::VMBuiltinLowerMutator::VisitExpr_(tvm::relax::CallNode const*)
  0: tvm::relax::VMBuiltinLowerMutator::Reshape(tvm::relax::Call const&)
  File "/software/tvm/src/relax/backend/vm/vm_builtin_lower.cc", line 120
TVMError: Check failed: (bound_val->IsInstance<ShapeExprNode>()) is false: VMBuiltinLower expects bound value to be a ShapeExpr
Environment
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(c0: T.Buffer((T.int64(2),), "int64"), c0_1: T.Buffer((T.int64(2),), "int64"), T_add: T.Buffer((T.int64(2),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads(c0[v_ax0], c0_1[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = c0[v_ax0] + c0_1[v_ax0]

    @T.prim_func(private=True)
    def multiply(lv0: T.Buffer((T.int64(2),), "int64"), c1: T.Buffer((T.int64(2),), "int64"), T_multiply: T.Buffer((T.int64(2),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads(lv0[v_ax0], c1[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = lv0[v_ax0] * c1[v_ax0]

    @R.function
    def main(data: R.Tensor((256,), dtype="float32"), c0: R.Tensor((2,), dtype="int64"), c1: R.Tensor((2,), dtype="int64")) -> R.Tensor(dtype="float32", ndim=2):
        cls = Module
        with R.dataflow():
            lv0 = R.call_tir(cls.add, (c0, c0), out_sinfo=R.Tensor((2,), dtype="int64"))
            target_shape = R.call_tir(cls.multiply, (lv0, c1), out_sinfo=R.Tensor((2,), dtype="int64"))
            lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
            gv: R.Tensor(dtype="float32", ndim=2) = R.reshape(data, lv2)
            R.output(gv)
        return gv


mod = Module
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')  # crash here!
cc @Lunderberg @junrushao
First observation: the calls to LegalizeOps, FuseTIR, and LambdaLift are not required to trigger this bug; building the module directly fails the same way, as sketched below.
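A minimal sketch of that observation, assuming the Module above is in scope: no explicit passes, only relax.build's default pipeline, which (as the traceback shows) runs VMBuiltinLower itself.

# Sketch: on an unpatched TVM this is expected to fail with the same
# "VMBuiltinLower expects bound value to be a ShapeExpr" error, even though
# LegalizeOps/FuseTIR/LambdaLift were never invoked explicitly.
ex = relax.build(Module, target='llvm')  # same TVMError without the fix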
It looks like VMBuiltinLower is being overzealous in its type-checking by requiring an in-line R.shape(...). So long as the expression has ShapeStructInfo, it should be valid, regardless of whether it is an in-line R.shape or something returned by a function call.
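To make the distinction concrete, here is a small inspection sketch (assuming the reproducer Module above; the binding index is read off from the dataflow block): the variable bound by R.tensor_to_shape carries ShapeStructInfo, but the bound value itself is a Call node rather than a ShapeExpr, which is exactly what the check at vm_builtin_lower.cc line 120 rejects.

main_fn = Module["main"]
dataflow_block = main_fn.body.blocks[0]
lv2_binding = dataflow_block.bindings[2]  # bindings: lv0, target_shape, lv2, gv

# The var's struct info is a shape, but the value bound to it is a call.
print(lv2_binding.var.struct_info)  # R.Shape(ndim=2)
print(type(lv2_binding.value))      # tvm.relax.expr.Call, not ShapeExpr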
@Cookiee235 The fix in #17218 resolves the test case on my end. Can you verify on your side?
@Lunderberg The test case also runs correctly under the given patch on my side! Thank you!
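For completeness, an end-to-end sanity check one might run with the patch applied (a sketch; the concrete values for c0 and c1 are my own illustration, chosen so that the computed target shape covers the 256 input elements):

import numpy as np

ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())

data = tvm.nd.array(np.random.rand(256).astype("float32"))
c0 = tvm.nd.array(np.array([1, 8], dtype="int64"))  # add(c0, c0) -> [2, 16]
c1 = tvm.nd.array(np.array([8, 1], dtype="int64"))  # multiply -> [16, 16]

out = vm["main"](data, c0, c1)
print(out.shape)  # expected: (16, 16), since 16 * 16 == 256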