[Bug] Graph optimization model compilation error involving `Pad` operator
I am trying to compile an ONNX model (graph below) using TVM.
Of course, this is a complicated graph, but it can be simplified as shown below.
These two graphs are equivalent, yet when I try to compile them with TVM, the original ONNX model fails while the simplified ONNX model passes. This is very strange!
This appears to be a shape-checking problem in the `Pad` operator.
In theory, TVM should be strongly compatible with native ONNX models. In practice, however, the result is not satisfactory: it seems that only simplified models are acceptable to TVM.
Expected behavior
ONNX compilation passes
Actual behavior
onnx fail
Traceback (most recent call last):
18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFun
13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
10: tvm::transform::Pass::operator()(tvm::IRModule) const
9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: tvm::relay::PadRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/pad.cc", line 131
InternalError: Check failed: (data->shape.size() == param->pad_width.size()) is false: There should be as many pad width pairs as shape dimensions but the shape has 5 dimensions and there are 4 pad width pairs.
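For context, the failing check can be reproduced directly in Relay with a minimal sketch (hypothetical shapes, not taken from the attached model): a 5-D input paired with only 4 pad-width pairs trips the same PadRel assertion.

import tvm
from tvm import relay

# 5-D input tensor, but only 4 (before, after) pad-width pairs.
x = relay.var("x", shape=(5, 5, 4, 2, 1), dtype="float32")
y = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# Type inference runs PadRel and raises the check above:
# 5 shape dimensions vs. 4 pad width pairs.
mod = relay.transform.InferType()(mod)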
Environment
Operating System: Ubuntu 18
TVM: 0.15
Torch: 2.1.1
ONNX: 1.15.0
Steps to reproduce
ONNX file is here: onnx.zip
Here is the script:

from onnxsim import simplify
import tvm
from tvm import relay
import onnx

def compile_onnx(onnx_model, shape):
    # Import the ONNX graph into Relay and build a graph executor
    # for CPU/LLVM at opt_level=4.
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(onnx_model, shape=shape)
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()

model = onnx.load('./model.onnx')
try:
    compile_onnx(model, {'v0_0': [], 'v6_0': [5, 5, 4, 2, 1]})
except Exception as e:
    print(f"onnx fail\n{e}")

# Simplify the model, validate it, then save and recompile.
model_simp, check = simplify(model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "./model_simp.onnx")
try:
    compile_onnx(model_simp, {'v0_0': [], 'v6_0': [5, 5, 4, 2, 1]})
except Exception as e:
    print(f"onnx-simplify fail\n{e}")
Triage
- needs-triage
Could you provide the simplified model for debugging?
Sorry for the late response. Below is the simplified model: model-sim.zip. In fact, the simplified model seems correct, but the original model cannot pass compilation. BTW, the simplification tool I used is here: https://github.com/daquexian/onnx-simplifier
@xhmelon
Hello, sorry to bother you again, but I still feel confused about this bug. @xhmelon
Is there any new progress on this issue?
I know you may not have time to investigate this due to your busy schedule.
From the original graph, it seems that the operand on the right side of the Pad should not be a list containing 10 pad values. That might be due to some optimization done by ONNX Simplifier, which defaults the pads list to 10 items. Judging from your error log, the issue is that your input data is 5-dimensional, but your pad attribute has only four pairs.
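For illustration (my own sketch, not code from TVM or ONNX Simplifier): ONNX stores pads as a flat list of 2 * rank values, [x1_begin, ..., xN_begin, x1_end, ..., xN_end], while Relay's pad expects rank-many (before, after) pairs, so a rank-5 tensor needs 10 flat values.

# Convert ONNX's flat pads list into the (before, after) pairs Relay expects.
def onnx_pads_to_pairs(pads):
    n = len(pads) // 2
    return list(zip(pads[:n], pads[n:]))

# A rank-5 tensor needs 10 flat values -> 5 pairs; 8 values would yield
# only 4 pairs and trip the PadRel check from the traceback above.
print(onnx_pads_to_pairs([0, 0, 1, 1, 2, 0, 0, 1, 1, 2]))
# [(0, 0), (0, 0), (1, 1), (1, 1), (2, 2)]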
@shaoyuyoung Sorry for the late reply, I have been too busy to work on this issue since then. I will continue to debug and expect to solve it this week. I believe this issue is caused by shape broadcasting like #16891.
Hi @shaoyuyoung,
The output shape of the If node's then branch is 5×5×3×4, while the else branch's is 5×5×3×4×1. The ONNX frontend in TVM attempts to broadcast the lower-rank output between these branches, which is unreasonable in our case. Since the predicate is a constant True, I added a check that skips the broadcast when the predicate is constant (see the sketch below). This workaround resolves the issue for this case, but the origin of the test case is still important.
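A rough sketch of what that check looks like (my paraphrase, not the actual patch; unify_dynamic_ranks is a hypothetical stand-in for the frontend's real rank-unification logic):

from tvm import relay

def maybe_unify_branches(cond, then_expr, else_expr):
    # With a compile-time-constant predicate, only one branch can ever
    # execute, so there is no need to broadcast the branch ranks.
    if isinstance(cond, relay.Constant):
        return then_expr, else_expr
    return unify_dynamic_ranks(then_expr, else_expr)  # hypothetical helper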
The comment in the broadcast code explains:
# Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks.
# Often these dont contribute anything. If we see a dynamic rank output, try to unify
# them so we can continue without breaking.
I’m wondering whether this case is automatically generated by PyTorch, as suggested in the comment, or if it’s designed intentionally.
Hi @xhmelon, thank you very much for your effort. If I understand correctly, you are asking why (or how) the model was generated in this case?
Honestly, I first defined a PyTorch model like the second graph in my original issue, then converted the PyTorch model to ONNX and got a model like the first graph.
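To make that path concrete, here is a minimal sketch of the kind of export I mean (hypothetical module and shapes; the real model is in the attached zip):

import torch
import torch.nn.functional as F

class PadModel(torch.nn.Module):
    def forward(self, x):
        # F.pad on a 5-D tensor; during export, control flow around ops
        # like this can become If nodes whose branches disagree on rank.
        return F.pad(x, (0, 1, 0, 1))

torch.onnx.export(PadModel(), torch.randn(5, 5, 4, 2, 1), "model.onnx")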