tvm
tvm copied to clipboard
[Bug] Inconsistent module structure and InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed
When applying the `LiftTransformParams()` transformation, the module produced by running it inside `tvm.transform.Sequential` (`mod_seq`) is not structurally equal to the module produced by applying the pass directly (`mod`). In addition, building the transformed module crashes.
The error may be related to how the symbolic shape variable `m` is handled: its value does not appear to be properly resolved (computed) during the transformation and build processes.
Actual behavior
File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 463
InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed
Steps to reproduce
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    # Elementwise acos over a (16, m, 3, 3) buffer. `m` is a symbolic int64
    # extent that is bound by matching the shape of `var_x`.
    @T.prim_func(private=True)
    def tir_acos(var_x: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m = T.int64()
        x = T.match_buffer(var_x, (T.int64(16), m, T.int64(3), T.int64(3)))
        compute = T.match_buffer(var_compute, (T.int64(16), m, T.int64(3), T.int64(3)))
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), m, T.int64(3), T.int64(3)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.acos(x[v_i0, v_i1, v_i2, v_i3])

    # Relax entry point. The result is tir_acos(w1), so the output shape
    # (16, m, 3, 3) depends on the symbolic `m`, which appears only in the
    # shapes of `w1`/`w2`. `x`, `w2` and `n` are declared but never used.
    @R.function
    def main(x: R.Tensor((1, 16, 224, "n"), dtype="float32"), w1: R.Tensor((16, "m", 3, 3), dtype="float32"), w2: R.Tensor((16, "m", 3, 3), dtype="float32")) -> R.Tensor((16, "m", 3, 3), dtype="float32"):
        m = T.int64()
        n = T.int64()
        # NOTE(review): "num_input": 1 presumably marks only `x` as a runtime
        # input, leaving w1/w2 as parameters that LiftTransformParams may lift
        # out of `main` — confirm against the pass documentation.
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.tir_acos, (w1,), out_sinfo=R.Tensor((16, m, 3, 3), dtype="float32"))
            R.output(gv)
        return gv
mod = Module
# Apply LiftTransformParams twice: once through a Sequential pipeline and once
# directly. The two results are expected to be structurally equal.
mod_seq = tvm.transform.Sequential([relax.transform.LiftTransformParams()])(mod)
mod = relax.transform.LiftTransformParams()(mod)
# Record the (non-raising) structural comparison *before* building. The build
# below is reported to raise InternalError ("PrimExpr m is not computed"),
# which would otherwise leave the equality check at the end unreachable and
# hide the first reported issue.
seq_matches = tvm.ir.structural_equal(mod_seq, mod)
print("mod_seq structurally equal to mod:", seq_matches)
ex = relax.build(mod, target='llvm')
# Strict check (raises with a structural diff) once the build succeeds.
tvm.ir.assert_structural_equal(mod_seq, mod)