
[Discuss] emit_te sugar in TVMScript

Open · psrivas2 opened this issue on Aug 01 '22 · 3 comments

TVMScript is very useful for creating Relax IRModules, especially for unit tests. However, since Relax does not have its own operator set yet, creating an IRModule that uses several well-known operators is a hassle in TVMScript.

In such cases, one currently has to generate the IRModule programmatically through the BlockBuilder, whose emit_te interface can generate the TIR implementations corresponding to TE compute.
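
For reference, emit_te lowers a TE compute into a TIR PrimFunc and emits a call to it. A minimal sketch of that lowering step, assuming the te.create_prim_func path (names here are illustrative):

from tvm import te, topi

# Describe the computation in TE...
A = te.placeholder((10, 20), dtype="float32", name="A")
B = topi.exp(A)

# ...and lower it to the TIR PrimFunc that emit_te would insert into
# the IRModule, next to a call_tir binding in the Relax function.
prim_func = te.create_prim_func([A, B])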

I propose we introduce an emit_te sugar in TVMScript as well, which would internally find the relevant op strategy and generate the corresponding TIR PrimFunc. The signature could look something like this:

lv = R.emit_te(<te_compute>, input_arg0, input_arg1, ..., attrs=<dictionary of attributes>)

Example: lv = R.emit_te(topi.add, x, y, attrs={'my_op_kind': 'addition operation', ...})
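
Under the hood, a call like this would presumably desugar into a call_tir on the generated PrimFunc, roughly as follows (using the call_tir convention of the time; te_add is an illustrative name for the PrimFunc generated from topi.add):

lv = R.call_tir(te_add, (x, y), (10, 20), dtype="float32")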

This would allow us to replace test cases like the ones below with TVMScript, which is easier to read.

Using BlockBuilder:

from tvm import relax, topi


# _check is a structural-equality helper assumed to be defined elsewhere in the test file.
def test_fuse_simple():
    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32"))

        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(
                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")])
                )
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())

Using TVMScript with the emit_te sugar:

import tvm
from tvm import relax, topi
from tvm.script import relax as R


def test_fuse_simple():
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim = 2):
            # block 0
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv
        
    
    @tvm.script.ir_module
    class Expected:
        @R.function
        def fused_add_exp_squeeze(x: Tensor((10, 20), "float32"), p0: Tensor((), "float32")) -> Tensor(None, "float32", ndim = 2):
            R.func_attr({"Primitive": 1})
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, p0)
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv
        @R.function
        def main(x1: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim = 2):
            with R.dataflow():
                gv1: Tensor((10, 20), "float32") = fused_add_exp_squeeze(x1, 1)
                R.output(gv1)
            return gv1

    _check(Before, Expected)

psrivas2 · Aug 01 '22

This seems to be related to meta-programming, cc @yelite @cyx-6 @junrushao1994

tqchen · Aug 01 '22

I totally agree it is useful. It can be done when the new parser is ready :)

Hzfengsy · Aug 03 '22

Great, looking forward to it! I'll keep the issue open just for tracking purposes.

psrivas2 · Aug 03 '22

Support was merged in tvm/unity: https://github.com/apache/tvm/pull/14123

yongwww · Feb 27 '23