
[Bug] Inconsistent Module Structure in Relax Transform and Build Failure with InlinePrivateFunctions()

[Open] Thrsu opened this issue 1 year ago • 2 comments

When applying the relax.transform.InlinePrivateFunctions() pass to a Relax module, the resulting module structure differs depending on whether the pass is applied through tvm.transform.Sequential or invoked directly. Additionally, when calling relax.build() on the directly transformed module, the build fails with the following internal error:

InternalError: Check failed: (slot->value_computed) is false: PrimExpr T.int64(4) * n * m in function I.GlobalVar("main") has not been computed.
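
For reference, the two application paths being compared are the following (extracted from the reproduction script below, with variable names spelled out):

mod_seq = tvm.transform.Sequential([relax.transform.InlinePrivateFunctions()])(Module)  # via Sequential
mod_direct = relax.transform.InlinePrivateFunctions()(Module)                           # direct invocation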

Expected behavior

The module structures produced by applying relax.transform.InlinePrivateFunctions() through Sequential and by direct application should be identical, and the transformed module should compile successfully under relax.build() without internal errors.

Actual behavior

  • The module structures differ between the two methods of applying the transformation (a minimal check is sketched after this list).
  • When the transformation is applied directly and relax.build() is then called, an internal error occurs, indicating that an expression involving n and m has not been computed.
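
A minimal check for the first point, using the mod_seq and mod variables from the reproduction script below, is:

print(tvm.ir.structural_equal(mod_seq, mod))  # True if the two results match; False given the reported inconsistency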

Steps to reproduce

import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_x2: T.handle, var_y2: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        x2 = T.match_buffer(var_x2, (n, m))
        y2 = T.match_buffer(var_y2, (n, m))
        T_add = T.match_buffer(var_T_add, (n, m))
        for ax0, ax1 in T.grid(n, m):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x2[v_ax0, v_ax1], y2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = x2[v_ax0, v_ax1] + y2[v_ax0, v_ax1]

    @R.function(private=True)
    def main_inner(x2: R.Tensor(("n", "m"), dtype="float32"), y2: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
        n = T.int64()
        m = T.int64()
        cls = Module
        sum_inner = R.call_tir(cls.add, (x2, y2), out_sinfo=R.Tensor((n, m), dtype="float32"))
        return sum_inner

    @R.function
    def main(x1: R.Tensor((10, 5), dtype="float32"), y1: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
        cls = Module
        sum_main: R.Tensor((10, 5), dtype="float32") = cls.main_inner(x1, y1)
        return sum_main

mod = Module

# Apply InlinePrivateFunctions() via Sequential and via direct invocation.
mod_seq = tvm.transform.Sequential([relax.transform.InlinePrivateFunctions()])(mod)
mod = relax.transform.InlinePrivateFunctions()(mod)
# The two results are expected to be structurally equal; the assert below fails:
#tvm.ir.assert_structural_equal(mod_seq, mod)

# Inspect the out_sinfo shape of the inlined call_tir binding in main.
print(mod["main"].body.blocks[0].bindings[0].value.sinfo_args[0].shape.values[0])

with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target='llvm')
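
As a possible diagnostic (a sketch, not a confirmed root-cause analysis): the error message suggests that the symbolic variables n and m from main_inner survive inlining into main, whose parameters have static (10, 5) shapes. If your TVM build provides relax.analysis.free_symbolic_vars (this helper is assumed here; availability may vary by version), the following would show whether any unbound symbolic variables remain in main after the direct transformation:

# Assumed helper: relax.analysis.free_symbolic_vars lists TIR variables that are
# used but never defined within the function. It would be expected to report n and m
# here if they indeed leak through the inlined call_tir's out_sinfo.
print(relax.analysis.free_symbolic_vars(mod["main"]))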

Could you please help confirm whether this is a bug in TVM or an issue with my usage?

Thrsu · Oct 21 '24