[Bug] [Relax] Missing IR structure checking and correction
Hi all, I set `check_well_formed=True` in the Relax IR construction below and can run `mod.show()` to print the IR successfully, so the Relax IR appears to pass the legitimacy check. However, compilation crashes when executing `ex = relax.build(mod, target='llvm')`. The crash message shows:
"Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))"
Based on my analysis, if we replace `gv1 = R.call_tir(cls.relu, (x), out_sinfo=R.Tensor((1, 512, 64, 64)))` (the line marked `# crash` in the reproducer below) with `gv1 = R.nn.relu(x)` or with `gv1 = R.call_tir(cls.relu, (x,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32"))` (the commented-out alternatives), the script runs fine.
If the Relax IR constructor can infer complete type information for `gv1 = R.nn.relu(x)` from the context, why does it not complete the missing dtype for the `R.call_tir` form of `gv1`?
Taking a step back, if the Relax IR constructor cannot complete the missing information and we set `check_well_formed=True`, an exception should be raised early at `mod = Module` rather than at `relax.build()`. Failing early would make the code more robust.
I would also prefer that the IR constructor fill in missing information or correct inconsistent constraints based on the IR's context.
Actual behavior
Traceback (most recent call last):
File "demo_simple.py", line 26, in <module>
ex = relax.build(mod, target='llvm') # crash here!
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
^^^^^^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
mod = seq(mod)
^^^^^^^^
File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
37: tvm::transform::Pass::operator()(tvm::IRModule) const
36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
35: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
30: tvm::relax::CallTIRMutator::Run()
29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
16: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
File "/software/tvm/src/relax/ir/block_builder.cc", line 159
TVMError: Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))
Environment
- TVM: 0.17.dev0
Steps to reproduce
```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module(check_well_formed=True)
class Module:
    @T.prim_func(private=True)
    # def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")):
    def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)))):
        T.func_attr({"op_pattern": 0})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)):
            with T.block("relu"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2, v_i3])
                T.writes(B[v_i0, v_i1, v_i2, v_i3])
                B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    @R.function
    def main(x: R.Tensor((1, 512, 64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv1 = R.call_tir(cls.relu, (x), out_sinfo=R.Tensor((1, 512, 64, 64)))  # crash
            # gv1 = R.nn.relu(x)  # run well
            # gv1 = R.call_tir(cls.relu, (x,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32"))  # run well
            R.output(gv1)
        return gv1


mod = Module
mod.show()
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')
```
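For what it's worth, the existing well-formed check also reports this module as fine before the build step, consistent with the late failure described above. A quick way to confirm (a sketch, assuming `tvm.relax.analysis.well_formed` is available in this build):

```python
# The module passes the current well-formed check, even though
# relax.build() later rejects it inside the CallTIRRewrite pass.
print(relax.analysis.well_formed(mod))  # expected: True
```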
cc @Lunderberg @junrushao @tqchen
Good catch, and I think this is arising from a number of different edge cases.
- The `StructInfo` for `R.call_tir` is always inferred from the `out_sinfo`, not from the TIR function's signature. This is for historical reasons, as TIR functions only recently started holding the annotations that would allow them to perform shape inference. As a result, no errors are seen during the initial well-formed check. (See the inspection sketch after this list.)
- The defaults of `T.Buffer` and `R.Tensor` are different. If unspecified, `T.Buffer` defaults to the `"float32"` datatype, whereas `R.Tensor` defaults to `DataType::Void`, which is used to represent an unknown datatype that might be inferred later in compilation. There is no equivalent in TIR, which must have a known datatype for each buffer.
- There is no rule that would infer the unknown Relax datatype from the mandatory TIR datatype. As a result, the `out_sinfo` remains the incomplete `R.Tensor(shape)`, rather than `R.Tensor(shape, dtype="float32")`.
- The error is raised during `CallTIRRewrite`, which rewrites low-level calls from having an implied allocation for the output to having an explicit argument for the output. Here, this rewrites `R.call_tir(cls.relu, [x], out_sinfo=R.Tensor([1,512,64,64]))` into `cls.relu(x, output_allocation)`, where `output_allocation` has shape `R.Tensor([1,512,64,64])`. This is the first point at which the TIR function's signature is actually inspected.
- Currently, when checking whether the constraints required by a subroutine are satisfied, the check must either pass or fail. There is no mechanism for the subroutine's constraints to be hoisted into the calling scope. Since "tensor of arbitrary element type" is not a valid argument for "tensor with float32 element type", the check fails.
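To make the first point above concrete, a quick inspection of the reproducer's parsed `Module` (a sketch; the attribute chain follows the Relax Python object model) shows that the binding's struct info is taken verbatim from `out_sinfo`, with the dtype still unknown rather than filled in from `cls.relu`:

```python
# Sketch: inspect the gv1 binding in the first reproducer's Module.
# Its struct info comes straight from out_sinfo, so the dtype stays unknown
# (DataType::Void) even though cls.relu requires float32 buffers.
main_fn = Module["main"]
gv1_binding = main_fn.body.blocks[0].bindings[0]
print(gv1_binding.var.struct_info)  # R.Tensor((1, 512, 64, 64)) with no dtype
```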
I think there are a number of improvements that could be made in order to close each of these loopholes.
- Improved well-formed checker. If `out_sinfo` is explicitly stated in `R.call_tir`, then `IsBaseOf(inferred_sinfo, out_sinfo)` must return true.
- Infer the dtype of `out_sinfo` in `R.call_tir`. If `out_sinfo` is a tensor, or a tuple of tensors, and one of those tensors has `DataType::Void()`, normalize the `out_sinfo` argument to include the datatype from the PrimFunc.
- Improved struct inference for `R.call_tir`. Now that PrimFuncs have a known shape for each argument, the output of `R.call_tir` could be improved. For backwards compatibility, an explicit `out_sinfo` argument would still take precedence. However, if `out_sinfo` is omitted (which currently causes an immediate error), the output struct info would instead be inferred by assuming that the last `len(params) - len(args)` parameters are output parameters.
- Improved normalization in the block builder. If an operator has restrictions on an argument, normalization could expose those constraints to the Relax level, rather than only marking the call as pass/fail. For example, normalization of an operator whose argument must be `DataType::Float(32)`, but which received `DataType::Void()`, could produce a binding of `new_arg = R.match_cast(arg, R.Tensor(arg.struct_info.shape, "float32"))`, then use `new_arg` in its call (see the hand-written sketch after this list).
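A hand-written sketch of what that last normalization could produce. This is hypothetical output rather than current TVM behavior, and the module/variable names are made up for illustration; only `R.match_cast` itself is an existing construct:

```python
from tvm.script import ir as I
from tvm.script import relax as R


# Hypothetical result of the proposed normalization: an argument with an
# unknown dtype (DataType::Void) is refined via R.match_cast before it is
# used where a float32 tensor is required.
@I.ir_module
class NormalizedSketch:
    @R.function
    def main(x: R.Tensor((16,))) -> R.Tensor((16,), dtype="float32"):
        # x has an unknown dtype; match_cast checks and refines it to float32.
        x_f32 = R.match_cast(x, R.Tensor((16,), dtype="float32"))
        return x_f32
```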
I think all of these would be useful changes to make, but some would have wider impacts than others. The well-formed checks could be added with the smallest risk of breakage, but also place the greatest load on new developers. Improved normalization would provide the greatest ease-of-use, but would require the most widespread changes. @tqchen, since some of these would be much more involved changes, do you have preferences/thoughts on them?
A similar bug occurs as shown below.
From what I have seen, the well-formed checker commonly corrects the return type and shape. However, when the return value of a Relax function is annotated as an `R.Tuple(...)`, the well-formed checker does not seem to catch the mismatch.
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/res_type.py", line 82, in <module>
mod_outputs = vm['main'](input_0, input_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::_LookupFunction(tvm::runtime::String const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
7: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
5: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
4: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
3: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
2: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
1: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16Pack
0: tvm::runtime::relax_vm::CheckTensorInfo(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
File "/software/tvm/src/runtime/relax_vm/builtin.cc", line 247
ValueError: Check failed: (DataType(ptr->dl_tensor.dtype) == dtype) is false: ErrorContext(fn=main, loc=return, annotation=R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32"))) expect Tensor with dtype float32 but get int32
Steps to reproduce
```python
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 1

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        cls = Module
        A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
        return (A, B, C)

    @R.function
    def main_2() -> R.Tuple(R.Tensor, R.Tensor):
        cls = Module
        args: R.Tuple(R.Tensor, R.Tensor, R.Tensor) = cls.func()
        gv1: R.Tensor = args[0]
        gv2: R.Tensor = args[2]
        return (gv1, gv2)

    @R.function
    def main(v3_0: R.Tensor((1, 22, 1), dtype="float16"), v6_0: R.Tensor((1, 37), dtype="float16")) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32")):  # if the return value is a tuple, the well-formed checker cannot correct it!
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            res: R.Tuple(R.Tensor, R.Tensor) = cls.main_2()
            R.output(res)
        return res


mod = Module
mod.show()
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())
input_0 = tvm.nd.array(10 * np.random.random([1, 22, 1]).astype('float16'))
input_1 = tvm.nd.array(10 * np.random.random([1, 37]).astype('float16'))
mod_outputs = vm['main'](input_0, input_1)
```
Hmm. I think this is something that should be catchable by propagating the known struct info, but currently isn't caught.
- In `main_2`, `cls.func()` returns a tuple with known dtypes and static shapes, but is assigned to a variable with unknown dtype and shape. This is legal, because the set of all `R.Tuple(R.Tensor, R.Tensor, R.Tensor)` is a superset of the set of all `R.Tuple(R.Tensor((16,16), "int32"), R.Tensor((16,16), "int32"), R.Tensor((32,32), "int32"))`. (See the inspection sketch after this list.)
- In `main`, even though the return type of `cls.main_2()` isn't explicitly specified, it gets inferred as `R.Tuple(R.Tensor, R.Tensor)`.
- The return type of `main` may be more specific than that of the body. This is intended to keep the return type stable: even if an optimization prevents shape inference from reaching all the way to the end of the function, the function still has accurate annotations. However, this means that the return struct info may be a sub-type of the body's struct info.
- Whenever the return type is a sub-type of the body's struct info, a runtime assert is inserted. This is the assert that triggers the error message.
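A short inspection sketch of where the precision is dropped in the second reproducer (assuming its `Module` is in scope; the attribute chain follows the Relax Python object model):

```python
# func's return annotation is fully typed...
print(Module["func"].ret_struct_info)
# R.Tuple(R.Tensor((16, 16), "int32"), R.Tensor((16, 16), "int32"), R.Tensor((32, 32), "int32"))

# ...but the user-provided annotation on `args` in main_2 is what the binding
# keeps, so the "int32" information is dropped from this point on.
args_binding = Module["main_2"].body.blocks[0].bindings[0]
print(args_binding.var.struct_info)
# R.Tuple(R.Tensor, R.Tensor, R.Tensor)
```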
I think this is a limitation in the StructInfo inference, which should catch the IRModule as ill-formed at compile-time, rather than runtime. However, it would first require a few extra steps of StructInfo inference that aren't currently performed.
- If an expression has more specific StructInfo than the variable it is bound to, propagate from the expression to the variable.
- If the body of a function has more specific StructInfo than the current return type, propagate from the body to the return type.
- If a function has more specific StructInfo than the GlobalVar used to represent it, propagate from the function to the GlobalVar.
In this example, these steps would allow the "int32" dtype returned by `cls.func` to be propagated through `main_2` and into `main`. At that point, returning "int32" from a function annotated as returning "float32" could be recognized as an error.
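For reference, the same kind of inspection (again a sketch) shows the gap that is currently closed by the runtime assert rather than by a compile-time error: the declared return annotation of `main` is fully typed, while all that inference knows about the returned value is the less specific tuple:

```python
main_fn = Module["main"]
print(main_fn.ret_struct_info)
# R.Tuple(R.Tensor((16, 16), "int32"), R.Tensor((32, 32), "float32"))  -- declared

print(main_fn.body.body.struct_info)
# R.Tuple(R.Tensor, R.Tensor)  -- inferred for the value actually returned
```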
One step toward making it harder for these inconsistent shapes to emerge has now been implemented: in https://github.com/apache/tvm/pull/17216, the `out_sinfo` field of `R.call_tir` is made optional and is inferred from the PrimFunc signature if omitted. While it doesn't yet catch the case where an explicit `out_sinfo` is inconsistent with the callee's signature, it does move in that direction.
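Under that change, the crashing line from the first reproducer could be written without an explicit `out_sinfo` at all (a sketch of the intended usage, assuming a TVM build that includes the PR):

```python
# out_sinfo is omitted and inferred from cls.relu's signature, so the output
# dtype can no longer be left half-specified here.
gv1 = R.call_tir(cls.relu, (x,))
```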