tvm
[Bug] AssertionError in LazyTransformParams
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/assert_lazy.py", line 52, in <module>
mod = relax.transform.LazyTransformParams()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 307, in _pass_func
return inst.transform_module(mod, ctx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 396, in transform_module
func = lazy_mutator.transform(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 151, in transform
forward_collector.visit_expr(func)
File "/software/tvm-lunder/python/tvm/relax/expr_functor.py", line 346, in visit_expr
return _ffi_api.PyExprVisitorVisitExpr(self, expr) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/meta_schedule/utils.py", line 76, in method
return getattr(inst, name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 59, in visit_var_binding_
assert isinstance(binding.value, relax.Tuple)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(C: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def multiply(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_multiply: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(2)

    @R.function
    def transform_params(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = Module
        C = R.call_tir(cls.multiply, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        D = R.call_tir(cls.add, (C, B), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        para0: R.Tensor((16, 16), dtype="float32") = B
        para1: R.Tensor((16, 16), dtype="float32") = B
        res: R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")) = cls.transform_params_2(para0, para1)
        return res

    @R.function
    def transform_params_2(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = Module
        C = R.call_tir(cls.multiply, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        D = R.call_tir(cls.add, (C, B), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        return (D, D)


mod = Module
mod = relax.transform.LazyTransformParams()(mod)  # crash here
cc @Lunderberg @junrushao
I assume there's a typo and that cls.transform_params_7 is supposed to be cls.transform_params_2. With that, I can reproduce your error.
This looks like a limitation of LazyTransformParams: it expects the tuple of outputs to be constructed within the function itself, rather than returned from a subroutine call. There are a couple of ways to work around this:
- Use LazyGetInput and LazySetOutput, which are intended to replace LazyTransformParams. These add callback arguments rather than performing lazy loading through the global "get_item" or "set_item" functions, and don't have the same limitation for tuple outputs. (The long-term plan is to reimplement LazyTransformParams so that it calls LazyGetInput and LazySetOutput internally.)
- If you only need transform_params and not transform_params_2, you could mark transform_params_2 as private, then use relax.transform.InlinePrivateFunctions(). This would move the (D, D) tuple into transform_params, working around the current limitation of LazyTransformParams.
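To make the limitation concrete, here is a toy model in plain Python (not TVM code). It assumes, based only on the failing assertion in the traceback (`assert isinstance(binding.value, relax.Tuple)`), that the pass looks up the binding of the returned variable and requires its value to be a tuple literal. `Var`, `TupleExpr`, `Call`, and `collect_output_indices` are hypothetical names. The model also shows why inlining transform_params_2, as in the second option, would satisfy the check; whether the real pass accepts the inlined form also depends on how the inliner binds the tuple, which this sketch does not model.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


# Hypothetical, heavily simplified stand-ins for Relax IR nodes (not TVM classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple[Var, ...]


@dataclass(frozen=True)
class Call:
    callee: str
    args: Tuple[Var, ...]


Expr = Union[Var, TupleExpr, Call]
Binding = Tuple[Var, Expr]  # (bound variable, bound value)


def collect_output_indices(bindings: List[Binding], ret: Var) -> Dict[str, List[int]]:
    """Map each output variable to the tuple index(es) it is written to.

    Toy version of the check seen in the traceback: the value bound to the
    returned variable must be a tuple literal, not e.g. a subroutine call.
    """
    out_map: Dict[str, List[int]] = {}
    for var, value in bindings:
        if var == ret:
            assert isinstance(value, TupleExpr), f"{var.name} is not bound to a tuple"
            for i, field in enumerate(value.fields):
                out_map.setdefault(field.name, []).append(i)
    return out_map


B = Var("B")
para0, para1, res = Var("para0"), Var("para1"), Var("res")

# Before inlining: `res` is the result of a call, so the check fails.
before = [
    (para0, B),
    (para1, B),
    (res, Call("transform_params_2", (para0, para1))),
]
try:
    collect_output_indices(before, res)
except AssertionError as err:
    print("AssertionError:", err)

# After (hypothetically) inlining transform_params_2, `res` is bound to a
# tuple literal, so each output can be written out by index.
C, D = Var("C"), Var("D")
after = [
    (C, Call("multiply", (para0,))),
    (D, Call("add", (C, para1))),
    (res, TupleExpr((D, D))),
]
print(collect_output_indices(after, res))  # {'D': [0, 1]}
```

In this model the same variable D appears at both output indices, mirroring the `return (D, D)` in the reproducer.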
@Lunderberg Thanks for your investigation. This information helps me understand the usage of the different transforms much more deeply. Because TVM's documentation is incomplete, it is difficult to understand how to use each transform from the source code alone. Your explanation helped me a lot! Thanks again.