[Bug][Relax] Shape Mismatch for function argument
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/res_ut/res_executions/30_test.py", line 50, in <module>
ex = relax.build(mod, target='llvm')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
mod = seq(mod)
^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
37: tvm::transform::Pass::operator()(tvm::IRModule) const
36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
35: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
30: tvm::relax::CallTIRMutator::Run()
29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
16: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
File "/software/tvm/src/relax/ir/block_builder.cc", line 159
TVMError: Argument 0 type mismatch: expected R.Tensor((64, 64, 56, 56), dtype="float32"), given R.Tensor((1, 64, 56, 56), dtype="float32")
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(data: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), weight1: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(58), T.int64(58)))
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(58), T.int64(58)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(57) and T.int64(1) <= v_i3 and v_i3 < T.int64(57), data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(64), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], weight1[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * weight1[v_ff, v_rc, v_ry, v_rx]

    @T.prim_func
    def relu(data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
            with T.block("root"):
                i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(data[i, j, k, l])
                T.writes(out[i, j, k, l])
                out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
        cls = Module
        with R.dataflow():
            conv1 = R.call_tir(cls.conv2d, (data, weight1), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
            relu1 = R.call_tir(cls.relu, (conv1,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"))
            R.output(relu1)
        return relu1

mod = Module
mod.show()
ex = relax.build(mod, target='llvm')
The given Relax IR passed the IR validity check but crashed when I built it. Could you help me review it? Thanks a lot!
CC @Lunderberg @junrushao
Hi @Cookiee235,
The error is caused by a mismatch between the output shape of conv2d and the input shape of relu, which are (1, 64, 56, 56) and (64, 64, 56, 56), respectively. I changed the shape of relu from (64, 64, 56, 56) to (1, 64, 56, 56), and the module now builds successfully.
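For concreteness, a minimal sketch of that fix, assuming the rest of Module is unchanged (main's return annotation would presumably need the same (1, 64, 56, 56) shape to stay consistent):

    # relu's buffers now match conv2d's (1, 64, 56, 56) output shape.
    @T.prim_func
    def relu(data: T.Buffer((1, 64, 56, 56), "float32"), out: T.Buffer((1, 64, 56, 56), "float32")):
        for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
            with T.block("root"):
                i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(data[i, j, k, l])
                T.writes(out[i, j, k, l])
                out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))

    # The call site in main must then use the matching out_sinfo:
    relu1 = R.call_tir(cls.relu, (conv1,), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))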
@xhmelon Thanks for your investigation. Indeed, the Relax IR is invalid, and the crash message gives the correct diagnosis. However, this Relax IR passes the verify_well_formed validation, which misleads us into considering it valid! It would be better to catch the error early (i.e., fail at the mod = Module statement)!
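A minimal sketch of the gap being described, assuming tvm.relax.analysis.well_formed is the check behind the verify_well_formed validation mentioned above:

    # The module is reported well-formed even though the out_sinfo of the
    # relu call_tir disagrees with relu's buffer shapes; per this report,
    # the mismatch only surfaces later, inside CallTIRRewrite at build time.
    from tvm.relax.analysis import well_formed

    print(well_formed(Module))  # True, despite the shape mismatch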