tvm
[Bug] InternalError: Check failed: (it != slot_map_.end()) is false: Var m is not defined in the function but is referenced by m * n during VM Shape Lowering
After applying the LiftTransformParams pass, building the module for the Relax VM fails in the VM shape lowering phase with the following error:
```
File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 310
InternalError: Check failed: (it != slot_map_.end()) is false: Var m is not defined in the function but is referenced by m * n
```
The error suggests a scoping problem: the symbolic variable m appears in a shape expression (m * n) but is not defined in the function being lowered. In main, m is defined by the shape of x and is passed explicitly to the TIR function through tir_vars, yet shape lowering still fails to resolve it.
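To make this failure mode concrete, here is a toy model in plain Python of how a shape-lowering pass might allocate slots for symbolic variables. It is a hypothetical sketch, not TVM's implementation: the function `lower_shapes` and its string-based shape encoding are invented for illustration. It assumes, as the error message suggests, that a symbolic variable gets a slot only when some parameter dimension is exactly that variable, while a compound dimension such as `m * n` can only be computed from variables that already have slots.

```python
import re


# Hypothetical toy model of symbolic-shape slot allocation; NOT TVM's code.
def lower_shapes(params, exprs):
    """params: {param_name: [dimension strings]}; exprs: shape expressions to compute.

    Returns {var_name: slot_index}, or raises if an expression uses an unbound var.
    """
    slot_map = {}
    to_check = list(exprs)
    for dims in params.values():
        for dim in dims:
            if dim.isidentifier():
                # A bare symbolic variable used as a dimension binds that variable.
                slot_map.setdefault(dim, len(slot_map))
            else:
                # A compound dimension (e.g. "m * n") cannot bind m or n by itself;
                # it can only be computed and checked once its variables are bound.
                to_check.append(dim)
    for expr in to_check:
        for var in re.findall(r"[A-Za-z_]\w*", expr):
            if var not in slot_map:
                raise RuntimeError(
                    f"Var {var} is not defined in the function but is referenced by {expr}"
                )
    return slot_map


# Like main: x has shape (m, n), so m and n are bound before m * n is needed.
print(lower_shapes({"x": ["m", "n"], "weight": ["m * n"]}, ["m * n"]))  # prints {'m': 0, 'n': 1}

# A function whose only parameter has shape (m * n,): nothing binds m.
try:
    lower_shapes({"weight": ["m * n"]}, ["m * n"])
except RuntimeError as err:
    print(err)
```

In the first call both m and n are bound by x, so m * n can be computed. In the second, the only parameter exposes m and n through the compound dimension m * n, so the model reports the same kind of message as the one above. Whether a function produced by LiftTransformParams actually ends up in this situation would need to be confirmed by inspecting the transformed module.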
Steps to reproduce
```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_weight: T.handle, var_T_add: T.handle, m: T.int64, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        weight = T.match_buffer(var_weight, (m * n,))
        T_add = T.match_buffer(var_T_add, (m * n,))
        for ax0 in range(m * n):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(m * n, ax0)
                T.reads(weight[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = weight[v_ax0] + T.float32(1)

    @R.function
    def main(x: R.Tensor(("m", "n"), dtype="float32"), weight: R.Tensor(("m * n",), dtype="float32")) -> R.Tensor(("m * n", 1, 1, 1), dtype="float32"):
        m = T.int64()
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.add, (weight,), out_sinfo=R.Tensor((m * n,), dtype="float32"), tir_vars=R.shape([m, n]))
            R.output(gv)
        return gv


mod = Module
with tvm.transform.PassContext(disabled_pass=["RemoveUnusedParameters"]):
    mod = relax.transform.FuseTIR()(mod)
    mod = tvm.relax.transform.LegalizeOps()(mod)
    mod = relax.transform.LambdaLift()(mod)
    mod = relax.transform.LiftTransformParams()(mod)
    ex = relax.build(mod, target='llvm')
```