Repository: tvm

[Bug] TVM cannot build the model correctly: InternalError: Check failed: value <= support::kMaxFloat16

Status: Open · opened by coffezhou 8 months ago · 0 comments

Expected behavior

TVM should build the model correctly.

Actual behavior

Traceback (most recent call last):
  File "/home/carla/Documents/test_tvm/0312/test_relax2.py", line 81, in <module>
    ex = relax.build(tvm_model, target="llvm")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 259, in build
    return _vmlink(
           ^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 154, in _vmlink
    lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 173, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/carla/Documents/tvm/python/tvm/tir/pipeline.py", line 122, in _pipeline
    mod = tvm.ir.transform.Sequential(passes)(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  57: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  56: tvm::transform::Pass::operator()(tvm::IRModule) const
  55: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  54: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  53: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  52: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  51: _ZN3tvm7runtime13PackedFuncObj
  50: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::Simplify()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::Simplify()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  49: tvm::arith::StmtSimplifier::Apply(tvm::tir::PrimFunc, tvm::arith::Analyzer*, tvm::runtime::Optional<tvm::arith::SimplifyConfig>)
  48: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  47: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  46: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  45: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  44: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  43: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  42: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  41: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::ObjectPtr<tvm::runtime::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  40: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  39: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  38: tvm::arith::StmtSimplifier::VisitStmt_(tvm::tir::ForNode const*)
  37: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  36: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  35: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  34: tvm::arith::StmtSimplifier::VisitStmt_(tvm::tir::ForNode const*)
  33: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  32: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  31: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  30: tvm::arith::StmtSimplifier::VisitStmt_(tvm::tir::ForNode const*)
  29: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  28: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  27: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  26: tvm::arith::StmtSimplifier::VisitStmt_(tvm::tir::ForNode const*)
  25: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  24: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  23: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  22: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  21: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  20: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  19: tvm::arith::StmtSimplifier::VisitStmt(tvm::tir::Stmt const&)
  18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  17: tvm::arith::StmtSimplifier::VisitStmt_(tvm::tir::BufferStoreNode const*)
  16: tvm::arith::StmtSimplifier::VisitExpr(tvm::PrimExpr const&)
  15: tvm::arith::Analyzer::Simplify(tvm::PrimExpr const&, int)
  14: tvm::arith::CanonicalSimplifier::operator()(tvm::PrimExpr const&)
  13: non-virtual thunk to tvm::arith::CanonicalSimplifier::Impl::VisitExpr(tvm::PrimExpr const&)
  12: tvm::arith::RewriteSimplifier::Impl::VisitExpr(tvm::PrimExpr const&)
  11: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10InitVTableEvENUlRKNS_7runt
  10: tvm::arith::CanonicalSimplifier::Impl::VisitExpr_(tvm::tir::DivNode const*)
  9: tvm::arith::RewriteSimplifier::Impl::VisitExpr_(tvm::tir::DivNode const*)
  8: non-virtual thunk to tvm::arith::CanonicalSimplifier::Impl::VisitExpr(tvm::PrimExpr const&)
  7: tvm::arith::RewriteSimplifier::Impl::VisitExpr(tvm::PrimExpr const&)
  6: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10InitVTableEvENUlRKNS_7runtime
  5: tvm::arith::CanonicalSimplifier::Impl::VisitExpr_(tvm::tir::CastNode const*)
  4: tvm::arith::RewriteSimplifier::Impl::VisitExpr_(tvm::tir::CastNode const*)
  3: tvm::cast(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span) [clone .localalias]
  2: tvm::PrimExpr tvm::tir::make_const<long, void>(tvm::runtime::DataType, long, tvm::Span)
  1: tvm::PrimExpr tvm::tir::MakeConstScalar<long>(tvm::runtime::DataType, long, tvm::Span)
  0: tvm::FloatImm::FloatImm(tvm::runtime::DataType, double, tvm::Span)
  File "/home/carla/Documents/tvm/src/ir/expr.cc", line 127
InternalError: Check failed: value <= support::kMaxFloat16 (261121 vs. 65504) : ValueError: Literal value 261121 exceeds maximum of float16

Environment

OS: Ubuntu 20.04; TVM: 0.20.dev0 (commit 6e8c367)

Steps to reproduce

This bug can be reproduced with the following code and the attached model. The model runs correctly under onnxruntime; however, TVM raises an InternalError when building it.

from typing import Dict, List, Literal, Optional
import sys

import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, TensorProto, helper, mapping

import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

import argparse

# Seeded Mersenne-Twister stream shared by the input generators below, so that
# every run of this reproducer draws exactly the same random inputs.
bg = np.random.MT19937(seed=0)
rg = np.random.Generator(bg)

def generate_random_inputs(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, np.ndarray]:
    """Build a feed dict holding one array per graph input of ``model``.

    Args:
        model: The ONNX model whose ``graph.input`` entries are enumerated.
        inputs: Optional arrays keyed by input name. An entry is used as-is
            when present and not ``None``; all other graph inputs are
            generated randomly.

    Returns:
        Mapping from each graph input name to its array. Generated arrays
        come from ``generate_random_value``. Their shape is taken from each
        dimension's ``dim_value`` and their dtype from ``elem_type``.
        NOTE(review): dimensions without a static ``dim_value`` (e.g.
        symbolic ones) presumably read as 0 here — confirm for dynamic
        models.
    """
    provided = inputs if inputs is not None else {}
    feed: Dict[str, np.ndarray] = {}
    for graph_input in model.graph.input:
        given = provided.get(graph_input.name)
        if given is not None:
            # Caller-supplied value wins over random generation.
            feed[graph_input.name] = given
            continue
        tensor_type = graph_input.type.tensor_type
        dims = [d.dim_value for d in tensor_type.shape.dim]
        feed[graph_input.name] = generate_random_value(dims, tensor_type.elem_type)
    return feed


def generate_random_value(
    shape, elem_type, rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """Return a random array of ``shape`` for an ONNX tensor element type.

    Args:
        shape: Sequence of ints giving the array shape.
        elem_type: ONNX element-type enum, mapped to a numpy dtype through
            ``helper.tensor_dtype_to_np_dtype``. A falsy value (0) falls back
            to ``float32``.
        rng: Optional numpy ``Generator`` to draw from. Defaults to the
            module-level seeded ``rg``, so results stay reproducible.

    Returns:
        - ``bool``: values drawn uniformly from {False, True}.
        - Signed integers: non-zero values in [-64, -1] and [1, 62].
        - Unsigned integers: non-zero values in [1, 63].
        - Anything else: standard-normal samples cast to the target dtype.
    """
    if rng is None:
        rng = rg

    # Extract datatype for the input.
    if elem_type:
        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
    else:
        dtype = "float32"

    if dtype == "bool":
        return rng.choice(a=[False, True], size=shape)

    if dtype.startswith("int"):
        # Keep non-zero values: draw from [-63, 62], then shift the
        # non-positive ones down by one so that 0 never appears.
        value = rng.integers(low=-63, high=63, size=shape).astype(dtype)
        value[value <= 0] -= 1
        return value

    if dtype.startswith("uint"):
        # Unsigned types used to fall through to the float path below.
        # Casting negative normal samples to an unsigned dtype is an invalid
        # cast (numpy warns and the values are garbage), so draw valid,
        # non-zero values from [1, 63] instead.
        return rng.integers(low=1, high=64, size=shape).astype(dtype)

    return rng.standard_normal(size=shape).astype(dtype)
    
# Load the reproducer model (shipped as the issue's model.zip attachment).
model_path = "model.onnx"
model = onnx.load(model_path)

# Random feeds for every graph input; passing None means nothing is pre-supplied.
inputs: Optional[Dict[str, np.ndarray]] = None
inputs = generate_random_inputs(model, inputs)

# Reference run: confirm onnxruntime accepts the model before blaming TVM.
try:
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)
except Exception as err:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and
    # hide why onnxruntime rejected the model; report the cause instead.
    print("This model cannot be executed by onnxruntime!")
    print(f"{type(err).__name__}: {err}", file=sys.stderr)
    sys.exit(1)

print(ort_output)

# Import the model into Relax, then run DecomposeOpsForInference and LegalizeOps.
tvm_model = from_onnx(model, keep_params_in_input=True)
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
tvm_model = relax.transform.LegalizeOps()(tvm_model)

tvm_model, params = relax.frontend.detach_params(tvm_model)

# relax.build is the call that raises the reported InternalError
# (value <= support::kMaxFloat16) during the TIR Simplify pass.
with tvm.transform.PassContext(opt_level=0):
    ex = relax.build(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())

Attachment: model.zip

Label: needs-triage

Posted by coffezhou, Mar 13 '25 02:03