tvm
[Relax][Bug] Segmentation fault when using the MergeCompositeFunctions transform
Actual behavior
Applying `relax.transform.MergeCompositeFunctions` to the module below crashes the process with: Segmentation fault (core dumped)
Environment
TVM: 0.17.dev0
OS: Ubuntu 20.04
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
# Minimal IRModule used to reproduce the MergeCompositeFunctions crash.
# `main` calls a composite function tagged "compiler_A.relu", then the plain
# TIR PrimFunc `relu` via R.call_tir, then a composite function tagged
# "compiler_A.gelu". That means a non-composite call_tir sits between the two
# composite calls.
# NOTE(review): presumably this interleaving is what triggers the crash; confirm.
@I.ir_module
class Module:
    # TIR PrimFunc: compute[i] = max(x11[i], 0) for each of the 10 float32
    # elements. It is invoked from `main` through R.call_tir.
    @T.prim_func(private=True)
    def relu(x11: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0 in range(T.int64(10)):
            with T.block("compute"):
                v_i0 = T.axis.spatial(T.int64(10), i0)
                T.reads(x11[v_i0])
                T.writes(compute[v_i0])
                compute[v_i0] = T.max(x11[v_i0], T.float32(0))

    # Composite function tagged "compiler_A.gelu" that wraps a single
    # R.nn.gelu on a (10,) float32 tensor.
    @R.function(private=True)
    def fused_relax_nn_gelu(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
        # `cls` is bound but not referenced in this body.
        cls = Module
        with R.dataflow():
            gv3 = R.nn.gelu(x21)
            R.output(gv3)
        return gv3

    # Composite function tagged "compiler_A.relu" that wraps a single
    # R.nn.relu on a (10,) float32 tensor.
    @R.function(private=True)
    def fused_relax_nn_relu(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
        # `cls` is referenced only by the disabled alternative below.
        cls = Module
        with R.dataflow():
            # Disabled alternative: call the TIR `relu` PrimFunc instead of R.nn.relu.
            # gv2 = R.call_tir(cls.relu, (x11,), out_sinfo=R.Tensor((10,), dtype="float32"))
            gv2 = R.nn.relu(x11)
            R.output(gv2)
        return gv2

    # Entry point: composite relu -> TIR relu (via call_tir) -> composite gelu.
    @R.function
    def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(x1)
            # Non-composite call placed between the two composite calls.
            lv2 = R.call_tir(cls.relu, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32"))
            lv3: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv2)
            R.output(lv3)
        return lv3
# Print the input IRModule, then run MergeCompositeFunctions on it.
# The pass application is the step that crashes with the reported
# segmentation fault.
mod = Module
mod.show()
merge_pass = relax.transform.MergeCompositeFunctions()
mod = merge_pass(mod)  # seg fault
Triage
- needs-triage
cc @junrushao