[Bug] InternalError: Check failed: (it != slot_map_.end()) is false: Var mis not defined in the function but is referenced by m * n during VM Shape Lowering

Open · Thrsu opened this issue 1 year ago · 0 comments

After applying the LiftTransformParams transformation, building the module with relax.build fails during Relax VM lowering, specifically in the VM Shape Lowering phase, with the following error:

File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 310
InternalError: Check failed: (it != slot_map_.end()) is false: Var mis not defined in the function but is referenced by m * n

The error seems to indicate a variable scope issue: m is referenced by the expression m * n, but the shape lowering pass does not find m defined in the function it is lowering (the log prints "Var mis" where "Var m is" is meant). The variable m is defined within main through the shape of x, and it is passed correctly in tir_vars, yet shape resolution still fails.
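To make the failing check concrete, here is a toy model in plain Python. It is not TVM code. It assumes that a symbolic variable counts as defined only when it appears as a bare dimension in some parameter shape, so a composite dimension such as m * n references m and n without defining them. The helper names defined_vars and undefined_refs are hypothetical. Whether LiftTransformParams actually produces a function that sees only weight is an assumption, not something confirmed in this report.

```python
import re

# Toy model only: shapes are tuples of strings, expressions are strings.
IDENT = r"[A-Za-z_]\w*"


def defined_vars(param_shapes):
    """Return the symbolic vars that appear as a bare dimension in a parameter shape."""
    defined = set()
    for shape in param_shapes:
        for dim in shape:
            dim = dim.strip()
            # "m" defines m; "m * n" is a composite expression and defines nothing.
            if re.fullmatch(IDENT, dim):
                defined.add(dim)
    return defined


def undefined_refs(param_shapes, exprs):
    """Map each referenced-but-undefined var to the first expression that uses it."""
    known = defined_vars(param_shapes)
    missing = {}
    for expr in exprs:
        for name in re.findall(IDENT, expr):
            if name not in known:
                missing.setdefault(name, expr)
    return missing
```

With the parameter shapes of main, ("m", "n") and ("m * n",), undefined_refs reports nothing for the expression "m * n". With only the shape of weight, ("m * n",), it reports both m and n as undefined, which mirrors the wording of the InternalError above.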

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_weight: T.handle, var_T_add: T.handle, m: T.int64, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        weight = T.match_buffer(var_weight, (m * n,))
        T_add = T.match_buffer(var_T_add, (m * n,))
        for ax0 in range(m * n):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(m * n, ax0)
                T.reads(weight[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = weight[v_ax0] + T.float32(1)

    @R.function
    def main(x: R.Tensor(("m", "n"), dtype="float32"), weight: R.Tensor(("m * n",), dtype="float32")) -> R.Tensor(("m * n", 1, 1, 1), dtype="float32"):
        m = T.int64()
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.add, (weight,), out_sinfo=R.Tensor((m * n,), dtype="float32"), tir_vars=R.shape([m, n]))
            R.output(gv)
        return gv

mod = Module
with tvm.transform.PassContext(disabled_pass=["RemoveUnusedParameters"]):
    mod = relax.transform.FuseTIR()(mod)
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.LambdaLift()(mod)
mod = relax.transform.LiftTransformParams()(mod)
ex = relax.build(mod, target='llvm')

Thrsu · Oct 27 '24 14:10