Repository: tvm
[Bug] [Relax] [LambdaLift] Argument type mismatch: expected R.Tensor, given R.Object
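Actual behavior (building the module after `relax.transform.LambdaLift` fails inside the `LowerRuntimeBuiltin` pass):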
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/llm.py", line 35, in <module>
ex = relax.build(mod, target='llvm')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-latest/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/software/tvm-latest/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-latest/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
raise_last_ffi_error()
File "/software/tvm-latest/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/software/tvm-latest/python/tvm/relax/pipeline.py", line 101, in _pipeline
mod = seq(mod)
^^^^^^^^
File "/software/tvm-latest/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-latest/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
raise_last_ffi_error()
File "/software/tvm-latest/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
31: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
30: tvm::transform::Pass::operator()(tvm::IRModule) const
29: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
28: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
26: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
25: _ZN3tvm7runtime13PackedFuncObj
24: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::LowerRuntimeBuiltin()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::LowerRuntimeBuiltin()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
23: tvm::relax::LowerRuntimeBuiltin(tvm::RelayExpr const&)
22: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
21: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
20: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
19: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
18: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
17: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
16: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
15: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
14: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
13: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
12: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
11: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::GlobalVarNode const*)
10: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
9: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
8: tvm::relax::LowerRuntimeBuiltinMutator::VisitExpr_(tvm::relax::CallNode const*)
7: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
6: tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
5: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
File "/software/tvm-latest/src/relax/ir/block_builder.cc", line 158
TVMError: Argument 0 type mismatch: expected R.Tensor((2, 3), dtype="float32"), given R.Object
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
# Reproducer IRModule for the LambdaLift issue: `main` defines a local
# function `outer_func`, which in turn defines and returns the nested
# function `inner_func`. `inner_func` captures `outer_func`'s parameter `c1`,
# so the value returned by `outer_func` is a closure.
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        # Not referenced anywhere later in this function.
        cls = Module

        # Returns `inner_func`. The trailing `True` in R.Callable is presumably
        # the callable's purity flag; confirm against the R.Callable signature.
        @R.function
        def outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True):
            # Nested function that reads `c1` from the enclosing scope.
            @R.function
            def inner_func(x1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
                s = R.add(x1,c1)
                return s
            return inner_func

        # `in_call` holds the closure produced by `outer_func`, with c1 bound to x.
        in_call: R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True) = outer_func(x)
        # NOTE(review): presumably this closure call is what, after LambdaLift,
        # ends up passing an R.Object where an R.Tensor((2, 3), "float32")
        # is expected. Confirm against the lifted IR printed by mod.show().
        res: R.Tensor((2, 3), dtype="float32") = in_call(y)
        res_1 = R.add(res,x)
        return res_1
# Apply lambda lifting directly to the parsed IRModule, then print the
# lifted module so the generated top-level functions can be inspected.
mod = relax.transform.LambdaLift()(Module)
mod.show()

# Per the traceback in this report, this build aborts inside the
# LowerRuntimeBuiltin pass with an "Argument 0 type mismatch" error.
with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target='llvm')
    vm = relax.VirtualMachine(ex, tvm.cpu())
CC @Lunderberg @junrushao