tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] [Relax] VMBuiltinLower expects bound value to be a ShapeExpr

Open Cookiee235 opened this issue 1 year ago • 3 comments

Actual behavior

Traceback (most recent call last):
  File "test_simple.py", line 46, in <module>
    ex = relax.build(mod, target='llvm') # crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 265, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 265, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  27: tvm::transform::Pass::operator()(tvm::IRModule) const
  26: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  21: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::VMBuiltinLower()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::VMBuiltinLower()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  20: tvm::relax::VMBuiltinLower(tvm::RelayExpr const&)
  19: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  18: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  16: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  15: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  11: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  10: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  9: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  8: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  7: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  6: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
  5: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  4: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  3: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  1: tvm::relax::VMBuiltinLowerMutator::VisitExpr_(tvm::relax::CallNode const*)
  0: tvm::relax::VMBuiltinLowerMutator::Reshape(tvm::relax::Call const&)
  File "/software/tvm/src/relax/backend/vm/vm_builtin_lower.cc", line 120
TVMError: Check failed: (bound_val->IsInstance<ShapeExprNode>()) is false: VMBuiltinLower expects bound value to be a ShapeExpr

Environment

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Reproducer IRModule for the reported failure: two private TIR kernels plus a
# Relax `main` that reshapes `data` using a shape value obtained from tensors
# at runtime (via R.tensor_to_shape) instead of an in-line R.shape(...).
@I.ir_module
class Module:
    # Element-wise addition over 2-element int64 buffers:
    # T_add[i] = c0[i] + c0_1[i].
    @T.prim_func(private=True)
    def add(c0: T.Buffer((T.int64(2),), "int64"), c0_1: T.Buffer((T.int64(2),), "int64"), T_add: T.Buffer((T.int64(2),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads(c0[v_ax0], c0_1[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = c0[v_ax0] + c0_1[v_ax0]

    # Element-wise multiplication over 2-element int64 buffers:
    # T_multiply[i] = lv0[i] * c1[i].
    @T.prim_func(private=True)
    def multiply(lv0: T.Buffer((T.int64(2),), "int64"), c1: T.Buffer((T.int64(2),), "int64"), T_multiply: T.Buffer((T.int64(2),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads(lv0[v_ax0], c1[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = lv0[v_ax0] * c1[v_ax0]

    # Entry point: computes (c0 + c0) * c1 as a 2-element int64 tensor, converts
    # that tensor to a shape, and reshapes `data` (256 float32 elements) to it.
    @R.function
    def main(data: R.Tensor((256,), dtype="float32"), c0: R.Tensor((2,), dtype="int64"), c1: R.Tensor((2,), dtype="int64")) -> R.Tensor(dtype="float32", ndim=2):
        cls = Module
        with R.dataflow():
            # lv0 = c0 + c0 (the same tensor is passed for both inputs of `add`).
            lv0 = R.call_tir(cls.add, (c0, c0), out_sinfo=R.Tensor((2,), dtype="int64"))
            # target_shape = lv0 * c1; still a tensor here, not a shape.
            target_shape = R.call_tir(cls.multiply, (lv0, c1), out_sinfo=R.Tensor((2,), dtype="int64"))
            # The shape is bound to the result of a call rather than an in-line
            # R.shape(...); per the reported error, this is the binding that the
            # VMBuiltinLower check rejects.
            lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
            gv: R.Tensor(dtype="float32", ndim=2) = R.reshape(data, lv2)
            R.output(gv)
        return gv

# Driver: apply the LegalizeOps, FuseTIR and LambdaLift transforms in sequence,
# then build the module for the 'llvm' target.
# NOTE(review): a follow-up comment in this thread reports that these three
# transform calls are not required to trigger the failure — confirm before
# minimizing the reproducer.
mod = Module
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
# The reported TVMError is raised from this call (see the traceback above).
ex = relax.build(mod, target='llvm') # crash here!

cc @Lunderberg @junrushao

Cookiee235 avatar Jul 30 '24 14:07 Cookiee235

First observation: the calls to LegalizeOps, FuseTIR, and LambdaLift are not required to trigger this bug.

It looks like VMBuiltinLower is being over-zealous in its type-checking by requiring an in-line R.shape(...). So long as the expression has ShapeStructInfo, it should be valid, regardless of whether it is an in-line R.shape(...) or something returned by a function call.

Lunderberg avatar Jul 30 '24 15:07 Lunderberg

@Cookiee235 The fix in #17218 resolves the test case on my end. Can you verify on your side?

Lunderberg avatar Jul 30 '24 15:07 Lunderberg

@Lunderberg The test case also runs correctly under the given patch on my side! Thank you!

Cookiee235 avatar Jul 30 '24 15:07 Cookiee235