[Bug] Constant folding cannot process onnx model correctly: InternalError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero

Open coffezhou opened this issue 8 months ago • 4 comments

Expected behavior

TVM should build the model correctly.

Actual behavior

Traceback (most recent call last):
  File "/home/carla/Documents/test_tvm/0312/test_relax2.py", line 81, in <module>
    ex = relax.build(tvm_model, target="llvm")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 253, in build
    mod = relax_pipeline(mod)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/carla/Documents/tvm/python/tvm/relax/pipeline.py", line 103, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/home/carla/Documents/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  40: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  39: tvm::transform::Pass::operator()(tvm::IRModule) const
  38: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  37: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  35: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  34: _ZN3tvm7runtime13PackedFuncObj
  33: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::RewriteDataflowReshape()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::RewriteDataflowReshape()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  32: tvm::relax::RewriteDataflowReshape(tvm::relax::Function const&, tvm::IRModule const&)
  31: tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
  30: _ZZN3tvm5relax11ExprFunctorIFNS_9RelaxExprERKS2_EE10InitVTableEvENUlRKNS_7r
  29: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  28: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  27: tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
  26: _ZZN3tvm5relax11ExprFunctorIFNS_9RelaxExprERKS2_EE10InitVTableEvENUlRKNS_7r
  25: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  24: tvm::relax::DataflowReshapeRewriter::VisitBindingBlock(tvm::relax::BindingBlock const&)
  23: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  22: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  21: tvm::relax::DataflowReshapeRewriter::VisitBinding_(tvm::relax::VarBindingNode const*)
  20: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  19: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::TupleGetItemNode const*)
  18: tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
  17: _ZZN3tvm5relax11ExprFunctorIFNS_9RelaxExprERKS2_EE10InitVTableEvENUlRKNS_7r
  16: tvm::relax::DataflowReshapeRewriter::VisitExpr_(tvm::relax::CallNode const*)
  15: tvm::relax::DataflowReshapeRewriter::IsCallingTIRReshape(tvm::relax::CallNode const*, tvm::RelaxExpr)
  14: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)
  13: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  12: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::BlockRealizeNode const*)
  11: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  10: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::BlockNode const*)
  9: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  8: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::ForNode const*)
  7: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  6: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::ForNode const*)
  5: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  4: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::BlockRealizeNode const*)
  3: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  2: tvm::relax::HasReshapePattern(tvm::tir::PrimFunc const&)::ReshapeDetector::VisitStmt_(tvm::tir::BlockNode const*)
  1: tvm::floormod(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::runtime::Optional<tvm::PrimExpr> tvm::arith::TryConstFold<tvm::tir::FloorMod>(tvm::PrimExpr, tvm::PrimExpr)
  File "/home/carla/Documents/tvm/src/arith/const_fold.h", line 321
InternalError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
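
The last two frames show the check failing while tvm::floormod constant-folds a FloorMod whose divisor is the constant 0, reached from HasReshapePattern. The trace does not say where that zero comes from. One possibility, stated here purely as an assumption, is a zero-extent dimension in a reshape: mapping a flat index back to per-dimension indices takes a modulo by each extent. The pure-Python sketch below only illustrates that arithmetic. It is not TVM code, and both function names are hypothetical.

```python
from typing import Sequence, Tuple


def floormod(a: int, b: int) -> int:
    """Floor modulo that rejects a zero divisor, loosely mirroring the failed check."""
    if b == 0:
        raise ZeroDivisionError("Divide by zero")
    return a % b  # Python's % on ints is already a floor modulo


def flat_to_multi_index(flat_index: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Convert a row-major flat index into one index per dimension."""
    indices = []
    for extent in reversed(shape):
        indices.append(floormod(flat_index, extent))
        flat_index //= extent
    return tuple(reversed(indices))


print(flat_to_multi_index(4, (2, 3)))  # (1, 1)

try:
    flat_to_multi_index(0, (2, 0, 3))  # the middle extent is 0
except ZeroDivisionError as err:
    print("failed:", err)  # failed: Divide by zero
```

If this assumption holds, any tensor shape with a 0 extent that reaches a reshape would trigger the same failure, which is something the attached model could be checked for.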

Environment

OS: Ubuntu 20.04
TVM: 0.20.dev0 (6e8c367)

Steps to reproduce

This bug can be reproduced with the following code and the attached model. onnxruntime runs the model correctly, but TVM raises an InternalError when building it.

from typing import Dict, Optional
import sys

import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, helper

import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

bg = np.random.MT19937(0)
rg = np.random.Generator(bg)

def generate_random_inputs(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, np.ndarray]:
    input_values = {}
    # Iterate through model inputs and extract their shape.
    for i in model.graph.input:
        if inputs is not None and i.name in inputs and inputs[i.name] is not None:
            input_values[i.name] = inputs[i.name]
            continue
        shape = []
        for dim in i.type.tensor_type.shape.dim:
            shape.append(dim.dim_value)

        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)

    return input_values


def generate_random_value(shape, elem_type) -> np.ndarray:

    # Extract datatype for the input.
    if elem_type:
        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
    else:
        dtype = "float32"

    # Generate random inputs for each input.
    if dtype == "bool":
        # random_value = np.random.choice(a=[False, True], size=shape)
        random_value = rg.choice(a=[False, True], size=shape)
    elif dtype.startswith("int"):
        # Keep non-zero values
        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
        random_value[random_value <= 0] -= 1
    else:
        random_value = rg.standard_normal(size=shape).astype(dtype)

    return random_value
    
model_path = "model.onnx"
model = onnx.load(model_path)

inputs: Optional[Dict[str, np.ndarray]] = None
inputs = generate_random_inputs(model, inputs)

try:
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)
except Exception as err:
    # Report the underlying error instead of swallowing it.
    print(f"This model cannot be executed by onnxruntime: {err}")
    sys.exit(1)

print(ort_output)
    
tvm_model = from_onnx(model, keep_params_in_input=True)
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
tvm_model = relax.transform.LegalizeOps()(tvm_model)

tvm_model, params = relax.frontend.detach_params(tvm_model)

with tvm.transform.PassContext(opt_level=0):
    ex = relax.build(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
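
The shape loop in generate_random_inputs reads dim.dim_value directly. In ONNX, a dimension that is symbolic (dim_param) or unset reports dim_value == 0, so such a dimension would produce a zero-sized input array. Nothing in this report says whether the attached model has such dimensions, so the following is only a diagnostic sketch. The helper name and the example shapes are hypothetical, and it works on plain shape lists like those the loop collects.

```python
from typing import Dict, List, Sequence, Tuple


def find_zero_extents(shapes: Dict[str, Sequence[int]]) -> List[Tuple[str, int]]:
    """Return (input name, axis) pairs for every dimension whose extent is 0."""
    return [
        (name, axis)
        for name, shape in shapes.items()
        for axis, extent in enumerate(shape)
        if extent == 0
    ]


# Made-up shapes standing in for what the script would extract from the model.
example_shapes = {"x": [1, 3, 224, 224], "mask": [1, 0]}
print(find_zero_extents(example_shapes))  # [('mask', 1)]
```

In the script above, the same check could be run on {name: value.shape for name, value in inputs.items()} before calling from_onnx.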

model.zip

  • needs-triage

coffezhou, Mar 13 '25 05:03

For relax.build() to work properly, it expects the model to be lowered to TensorIR before compilation.

Before calling relax.build(), add this transformation: tvm_model = relax.transform.LowerToTensorIR()(tvm_model)

It should work.

Kushagra-88, Mar 13 '25 12:03

@Kushagra-88 Thanks for your reply! I have tried your method, but it gives the following output:

tvm_model = relax.transform.LowerToTensorIR()(tvm_model)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'tvm.relax.transform' has no attribute 'LowerToTensorIR'

I also searched for 'LowerToTensorIR' in the TVM repository on GitHub, but it does not exist in the current repository. According to the documentation, 'tvm_model = relax.transform.LegalizeOps()(tvm_model)', which the script already calls, should lower the model to TensorIR.
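
A quick way to confirm whether a suggested pass exists in the installed build is to look the name up before calling it. This is a minimal sketch using only the standard library. The helper name has_attribute is hypothetical, and it returns None when the module cannot be imported at all.

```python
import importlib
from typing import Optional


def has_attribute(module_name: str, attr: str) -> Optional[bool]:
    """Return True or False if the module imports, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


# Look up the two pass names discussed in this thread.
for name in ("LowerToTensorIR", "LegalizeOps"):
    print(name, has_attribute("tvm.relax.transform", name))
```

Given the AttributeError above, the first lookup would be expected to report False on this build, while LegalizeOps is already used by the reproduction script.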

coffezhou, Mar 15 '25 12:03

Did you try LegalizeOps()? Did it work?

SrivaniJayanthi, Mar 17 '25 05:03

No, the bug still occurs after using LegalizeOps().

coffezhou, Mar 17 '25 11:03