[Discuss] emit_te sugar in TVMScript
TVMScript is very useful for creating a Relax IRModule, especially in unit tests. However, since Relax does not have its own operator set yet, creating an IRModule that uses several well-known operators is a hassle in TVMScript.
In such cases, one currently has to either:
- build the IRModule directly with `BlockBuilder` (see examples), or
- build it with `BlockBuilder`, then print the resulting TVMScript and copy it into the test.

In both cases we rely on `BlockBuilder`, because its `emit_te` interface can generate the TIR implementation corresponding to a TE compute.
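For reference, here is a minimal, self-contained sketch of that `BlockBuilder` path (using the same unity-era API as the tests below; the printed module contains both the Relax `main` and the generated TIR `add` PrimFunc):

```python
import tvm
from tvm import relax, topi

bb = relax.BlockBuilder()
x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        # emit_te lowers the TE compute into a TIR PrimFunc in the module
        # and binds the result of a call_tir to a fresh var.
        lv = bb.emit_te(topi.add, x, relax.const(1, "float32"))
        gv = bb.emit_output(lv)
    bb.emit_func_output(gv)

mod = bb.get()
print(mod.script())  # shows the Relax "main" plus the generated TIR "add"
```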
I propose we introduce `emit_te` sugar in TVMScript as well, which would internally find the relevant op strategy and generate the corresponding TIR PrimFunc. The signature could look something like the following:
```python
lv = R.emit_te(<te_compute>, input_arg0, input_arg1, ..., attrs=<dictionary of attributes>)

# example:
lv = R.emit_te(topi.add, x, y, attrs={'my_op_kind': 'addition operation', ...})
```
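Internally, the parser could presumably reuse the same machinery as `BlockBuilder.emit_te`. A rough desugaring sketch for a single `R.emit_te(topi.add, x, y)` call (hypothetical, for illustration only):

```python
from tvm import te, topi

# Placeholders standing in for the Relax arguments x and y (shapes assumed).
A = te.placeholder((10, 20), dtype="float32", name="x")
B = te.placeholder((10, 20), dtype="float32", name="y")
C = topi.add(A, B)

# The generated TIR PrimFunc that would back the call_tir emitted for lv.
prim_func = te.create_prim_func([A, B, C])
print(prim_func.script())
```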
This would allow us to replace test cases like the one below with TVMScript, which is easier to read.
Using `BlockBuilder`:

```python
def test_fuse_simple():
    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32"))
        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(
                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")])
                )
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
```
Using TVMScript with `emit_te` sugar:

```python
def test_fuse_simple():
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim=2):
            # block 0
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def fused_add_exp_squeeze(
            x: Tensor((10, 20), "float32"), p0: Tensor((), "float32")
        ) -> Tensor(None, "float32", ndim=2):
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, p0)
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv

        @R.function
        def main(x1: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim=2):
            with R.dataflow():
                gv1: Tensor((10, 20), "float32") = fused_add_exp_squeeze(x1, 1)
                R.output(gv1)
            return gv1

    _check(Before, Expected)
```
This seems to be related to meta-programming, cc @yelite @cyx-6 @junrushao1994
I totally agree it is useful. It can be done when the new parser is ready :)
Great, looking forward to it! I'll keep the issue open just for tracking purposes.
Support was merged in tvm/unity: https://github.com/apache/tvm/pull/14123