tvm
[Bug] LegalizeOps failed: InternalError: Check failed: (strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) is false
Expected behavior
TVM should build the model correctly.
Actual behavior
Traceback (most recent call last):
File "/home/carla/Documents/test_tvm/0321/test_relax2.py", line 75, in <module>
tvm_model = relax.transform.LegalizeOps()(tvm_model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
raise py_err
File "/home/carla/Documents/tvm/src/relax/transform/legalize_ops.cc", line 398, in operator()
mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/src/relax/transform/legalize_ops.cc", line 74, in tvm::relax::LegalizeMutator::Transform()
auto updated_func = Downcast<Function>(this->VisitExpr(func));
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/src/relax/transform/legalize_ops.cc", line 343, in tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
Expr legalized = legalization_func(builder_, visited_call);
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/home/carla/Documents/tvm/python/tvm/relax/transform/legalize_ops/index.py", line 62, in _strided_slice
return bb.call_te(
^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/relax/block_builder.py", line 356, in call_te
tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/relax/utils.py", line 354, in gen_call_tir_inputs
te_out = func(*te_args, **te_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/topi/transform.py", line 228, in strided_slice
return cpp.strided_slice(a, begin, end, strides, axes, slice_mode, assume_inbound)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 284, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/home/carla/Documents/tvm/src/topi/transform.cc", line 195, in operator()
*rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode);
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/include/tvm/topi/transform.h", line 899, in tvm::topi::strided_slice_with_axes(tvm::te::Tensor const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
slice_mode, begin_expr);
^^^^^^
File "/home/carla/Documents/tvm/include/tvm/topi/detail/strided_slice.h", line 140, in tvm::topi::detail::StridedSliceOutputShape(tvm::runtime::Array<tvm::PrimExpr, void> const&, std::vector<long, std::allocator<long> > const&, std::vector<long, std::allocator<long> > const&, std::vector<long, std::allocator<long> > const&, tvm::runtime::Array<tvm::Integer, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::Array<tvm::PrimExpr, void> const&, bool)
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
^^^^^^^^^^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Traceback (most recent call last):
2: operator()
at /home/carla/Documents/tvm/src/topi/transform.cc:195
1: tvm::topi::strided_slice_with_axes(tvm::te::Tensor const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::runtime::Array<tvm::Integer, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
at /home/carla/Documents/tvm/include/tvm/topi/transform.h:899
0: tvm::topi::detail::StridedSliceOutputShape(tvm::runtime::Array<tvm::PrimExpr, void> const&, std::vector<long, std::allocator<long> > const&, std::vector<long, std::allocator<long> > const&, std::vector<long, std::allocator<long> > const&, tvm::runtime::Array<tvm::Integer, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::Array<tvm::PrimExpr, void> const&, bool)
at /home/carla/Documents/tvm/include/tvm/topi/detail/strided_slice.h:140
File "/home/carla/Documents/tvm/include/tvm/topi/detail/strided_slice.h", line 140
InternalError: Check failed: (strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) is false: : Input [Begin=-1, End=1] is invalid for axis=0
Environment
OS: Ubuntu 20.04 TVM: 0.20.dev0 (f6236ce41)
Steps to reproduce
This bug can be reproduced with the following code and the attached model. The model runs correctly under onnxruntime; however, an InternalError occurs when TVM builds it (specifically, in the `relax.transform.LegalizeOps()` pass).
from typing import Dict, List, Literal, Optional
import sys
import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, TensorProto, helper, mapping
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import argparse
# Module-wide RNG: a Mersenne Twister bit generator seeded with 0, wrapped in
# a NumPy Generator, so every run of this script draws the same input values.
bg = np.random.MT19937(seed=0)
rg = np.random.Generator(bg)
def generate_random_inputs(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, np.ndarray]:
    """Build a feed dict with one array per entry of ``model.graph.input``.

    Arrays supplied in ``inputs`` (and not ``None``) are used as-is; every
    other graph input gets a random array from ``generate_random_value``,
    shaped after the input's declared tensor shape.

    Args:
        model: ONNX model whose graph inputs are to be fed.
        inputs: Optional caller-provided arrays keyed by input name.

    Returns:
        Mapping from input name to NumPy array.
    """
    input_values = {}
    for graph_input in model.graph.input:
        provided = None if inputs is None else inputs.get(graph_input.name)
        if provided is not None:
            input_values[graph_input.name] = provided
            continue
        tensor_type = graph_input.type.tensor_type
        # Each dimension is either a concrete ``dim_value`` or a symbolic
        # ``dim_param`` (or unset). For the latter, ``dim_value`` reads as 0,
        # which would silently produce an empty tensor, so use 1 instead.
        # An explicitly declared 0-size dimension is still honoured.
        shape = [
            dim.dim_value if dim.HasField("dim_value") else 1
            for dim in tensor_type.shape.dim
        ]
        input_values[graph_input.name] = generate_random_value(
            shape, tensor_type.elem_type
        )
    return input_values
def generate_random_value(shape, elem_type) -> np.ndarray:
    """Return a random NumPy array of ``shape`` for an ONNX element type.

    Values are drawn from the module-level generator ``rg``:

    * bool: uniformly ``False``/``True``;
    * unsigned integers: integers in ``[1, 63]`` (never zero, never negative);
    * signed integers: integers in ``[-64, -1]`` or ``[1, 62]`` (never zero);
    * anything else: standard-normal samples cast to the target dtype.

    Args:
        shape: Sequence of dimension sizes.
        elem_type: ONNX ``TensorProto`` element-type code; a falsy value
            (0 / unset) falls back to ``float32``.

    Returns:
        The generated array.
    """
    # Map the ONNX element type to a NumPy dtype name.
    if elem_type:
        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
    else:
        dtype = "float32"

    if dtype == "bool":
        random_value = rg.choice(a=[False, True], size=shape)
    elif dtype.startswith("uint"):
        # Unsigned types cannot hold negatives: casting negative floats to an
        # unsigned dtype wraps around. Draw strictly positive integers so the
        # "non-zero" convention of the signed branch still holds.
        random_value = rg.integers(low=1, high=64, size=shape).astype(dtype)
    elif dtype.startswith("int"):
        # Draw from [-63, 63), then shift every value <= 0 down by one:
        # the result lies in [-64, -1] or [1, 62], so it is never zero.
        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
        random_value[random_value <= 0] -= 1
    else:
        random_value = rg.standard_normal(size=shape).astype(dtype)
    return random_value
# Load the model under test and build a matching set of random inputs.
model_path = "model.onnx"
model = onnx.load(model_path)
inputs: Optional[Dict[str, np.ndarray]] = None
inputs = generate_random_inputs(model, inputs)

# Sanity check: the model must run under onnxruntime; otherwise a TVM failure
# below would not indicate a TVM bug. Catch only ordinary exceptions, so that
# Ctrl-C / SystemExit still behave normally, and report the actual error
# instead of discarding it.
try:
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)
except Exception as err:
    print("This model cannot be executed by onnxruntime!")
    print(f"onnxruntime error: {err!r}", file=sys.stderr)
    sys.exit(1)

# Import into Relax and lower it. In the reported traceback, the
# LegalizeOps() call is the step that raises the InternalError.
tvm_model = from_onnx(model, keep_params_in_input=True)
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
tvm_model = relax.transform.LegalizeOps()(tvm_model)
tvm_model, params = relax.frontend.detach_params(tvm_model)

# Compile for CPU and instantiate the Relax virtual machine.
with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(tvm_model, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())