[Bug][Relay] cannot squeeze axis with dimension not equal to 1 at Keras frontend
For the LSTM model below, when batch_size != 1 (i.e., the size of the first input dimension is not 1), compilation crashes unexpectedly with Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1.
Note that when batch_size == 1, TVM runs fine.
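The failure mode can be mimicked in plain NumPy, whose squeeze has the same constraint as Relay's squeeze (this is an analogy to illustrate the check, not TVM's actual code path): squeezing an axis is only legal when that axis has extent 1, so a squeeze that the frontend emits assuming batch_size == 1 breaks as soon as the batch axis has extent 2.

```python
import numpy as np

# batch_size == 1: squeezing the batch axis is legal.
out_b1 = np.zeros((1, 2))                 # (batch=1, units=2)
print(np.squeeze(out_b1, axis=0).shape)   # (2,)

# batch_size == 2: the same squeeze is illegal, the axis extent is 2.
out_b2 = np.zeros((2, 2))                 # (batch=2, units=2)
try:
    np.squeeze(out_b2, axis=0)
except ValueError as err:
    print("squeeze failed:", err)
```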
Question:
For the LSTM model, why does batch_size != 1 lead to a crash?
Expected behavior
Compiles successfully
Actual behavior
Traceback (most recent call last):
File "test.py", line 18, in <module>
model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
File "/workplace/software/tvm/tvm/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
return self._make_executor()
File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 219, in _make_executor
self.executable = compile(self.mod, self.target)
File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 67, in compile
compiler.lower(mod, target, target_host)
File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 126, in lower
self._lower(mod, raw_targets)
File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
15: TVMFuncCall
14: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
13: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
12: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
11: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
10: tvm::transform::Pass::operator()(tvm::IRModule) const
9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
6: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
2: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
1: tvm::relay::TypeSolver::Solve()
0: _ZN3tvm7runtime6detail
19: TVMFuncCall
18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
17: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
16: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
15: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
14: tvm::transform::Pass::operator()(tvm::IRModule) const
13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
10: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
6: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
5: tvm::relay::TypeSolver::Solve()
4: tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const
3: _ZN3tvm7runtime13Pac
2: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
1: tvm::relay::SqueezeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
0: _ZN3tvm7runtime6detail
File "/workplace/software/tvm/tvm/src/relay/analysis/type_solver.cc", line 643
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (false) is false: [07:40:29] /workplace/software/tvm/tvm/src/relay/op/tensor/transform.cc:2340:
---------------------------------------------------------------
Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1
Steps to reproduce
import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models
input_shape = (2, 3, 4)
x = layers.Input(shape=input_shape[1:], dtype='float32')
layer = keras.layers.LSTM(units=2)
layer.set_weights(layer.get_weights())
y = layer(x)
model = models.Model(x, y)
model.summary()
mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
print(mod)
with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
Triage
- frontend:keras
cc @shingjan
This is a bug in the LSTM layer. I'm not very familiar with RNN networks. For the LSTM model, why does batch_size != 1 lead to a crash in TVM?
@echuraev Could you give me some suggestions? Thank you in advance.
@jikechao I suppose that the problem is the same as in #14868. In LSTM we have the same logic here as for Simple RNN.
@echuraev Thanks for your explanation. I'd like to fix this bug.
@jikechao I took a quick look at this issue, and it looks like one possible solution might be to use reshape instead of squeeze in the LSTM layer implementation. I did the same trick in this PR: https://github.com/apache/tvm/pull/16526
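The reshape-over-squeeze idea can be sketched in NumPy terms (again an analogy, not the actual Relay patch; drop_time_axis is a hypothetical helper introduced here for illustration): a reshape to the statically known target shape removes the unit time axis and keeps working for any batch size, whereas a squeeze of the wrong axis would only work when that axis happens to be 1.

```python
import numpy as np

def drop_time_axis(step):
    # Reshape to the known target shape instead of squeezing:
    # valid for batch_size == 1 and batch_size > 1 alike.
    batch, _, feats = step.shape
    return step.reshape(batch, feats)

print(drop_time_axis(np.zeros((1, 1, 4))).shape)  # (1, 4)
print(drop_time_axis(np.zeros((2, 1, 4))).shape)  # (2, 4)
```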
@echuraev, Thank you! I submitted a PR to fix it. Could you help me review it?