tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] Inconsistent module structure and InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed

Open Thrsu opened this issue 1 year ago • 0 comments

After applying the `LiftTransformParams()` transformation, the resulting module structure is inconsistent between running the pass inside a `tvm.transform.Sequential` (`mod_seq`) and applying the pass directly (`mod`). In addition, building the module after the transformation crashes.

The error may be related to how `m` is handled: it is a dynamic shape variable whose value must be computed, and it may not be properly resolved during the transformation and build steps.

Actual behavior

  File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 463
InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed

Steps to reproduce

import tvm
from tvm import relax
import numpy as np

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Reproduction module: one private TIR kernel (`tir_acos`) plus a Relax entry
# point (`main`) that calls it on a weight tensor with a symbolic dimension "m".
@I.ir_module
class Module:
    # Element-wise arccosine kernel: compute[...] = acos(x[...]) over a 4-D
    # buffer of shape (16, m, 3, 3), where `m` is a symbolic int64 extent.
    @T.prim_func(private=True)
    def tir_acos(var_x: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic extent shared by the input and output buffers below.
        m = T.int64()
        x = T.match_buffer(var_x, (T.int64(16), m, T.int64(3), T.int64(3)))
        compute = T.match_buffer(var_compute, (T.int64(16), m, T.int64(3), T.int64(3)))
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), m, T.int64(3), T.int64(3)):
            with T.block("compute"):
                # All four axes are remapped as "S" axes, one per loop variable.
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.acos(x[v_i0, v_i1, v_i2, v_i3])

    # Relax entry point. Only `w1` is used in the body; `x` and `w2` are
    # declared but never read. The result has the same (16, m, 3, 3) shape as w1.
    @R.function
    def main(x: R.Tensor((1, 16, 224, "n"), dtype="float32"), w1: R.Tensor((16, "m", 3, 3), dtype="float32"), w2: R.Tensor((16, "m", 3, 3), dtype="float32")) -> R.Tensor((16, "m", 3, 3), dtype="float32"):
        # Symbolic shape variables: "m" appears in the shapes of w1/w2 and of
        # the return value; "n" appears only in the shape of x.
        m = T.int64()
        n = T.int64()
        # NOTE(review): with num_input = 1, presumably only `x` is treated as a
        # runtime input and w1/w2 as parameters eligible for lifting, so `m`
        # would be defined solely by parameter shapes — confirm against the
        # LiftTransformParams documentation.
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            # Apply the acos kernel to w1; the output shape reuses symbolic `m`.
            gv = R.call_tir(cls.tir_acos, (w1,), out_sinfo=R.Tensor((16, m, 3, 3), dtype="float32"))
            R.output(gv)
        return gv

# Driver for the reproduction. The statement order matters: the build runs
# before the structural-equality check.
mod = Module
# Variant 1: run LiftTransformParams wrapped in a single-pass Sequential.
mod_seq = tvm.transform.Sequential([relax.transform.LiftTransformParams(), ])(mod)
# Variant 2: apply the same pass directly; this rebinds `mod` to the result.
mod = relax.transform.LiftTransformParams()(mod)
# Build the directly-transformed module for LLVM. Per the report, this is the
# step that crashes with "PrimExpr m is not computed".
ex = relax.build(mod, target='llvm')
# Compare the two transformed modules; reached only if the build did not raise.
tvm.ir.assert_structural_equal(mod_seq, mod)

Thrsu avatar Oct 27 '24 14:10 Thrsu