
[WIP][DO NOT MERGE] feat: add Python pass for DRR

Open · FlamingoPg opened this issue 1 year ago · 3 comments

PR Category

User Experience

PR Types

New features

Description

This is a first draft of the DRR pass. A possible usage example:

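```python
import paddle
import paddle.pir as pir

# NOTE: DrrPatternContext, SourcePattern/ResultPattern, Op, Tensor, Attr,
# ComputeAttr, AddConstraint, GetShapeFromValue and pass_register are the
# Python bindings proposed in this WIP PR; names and signatures may change.


def matmul_transpose_fuse_pass():
    python_ctx = pir.DrrPatternContext()

    def pattern_pass():
        python_pat = python_ctx.SourcePattern()
        # Would it be better to simplify the naming here and let the lower
        # layer map a local name to the full pd_op.* name?
        matmul_op = python_pat.Op(
            "pd_op.matmul",
            {"transpose_x": python_pat.Attr("transpose_x"),
             "transpose_y": python_pat.Attr("transpose_y")})
        transpose_op = python_pat.Op(
            "pd_op.transpose", {"perm": python_pat.Attr("perm")})

        # Unify op application into a single op(outputs, inputs) form to reduce
        # development effort (Python cannot assign to a call such as
        # python_pat.Tensor("out") = op(...)).
        transpose_op([python_pat.Tensor("x_transpose_out")],
                     [python_pat.Tensor("x")])
        matmul_op([python_pat.Tensor("matmul_op_out")],
                  [python_pat.Tensor("x_transpose_out"), python_pat.Tensor("y")])

        # It seems a callback is unavoidable here...
        python_pat.AddConstraint(cons_function)
        return python_pat

    def cons_function(match_ctx):
        # Both matmul operands must be at least 2-D.
        x_shape = pir.GetShapeFromValue(match_ctx.Tensor("x"))
        y_shape = pir.GetShapeFromValue(match_ctx.Tensor("y"))
        if len(x_shape) < 2 or len(y_shape) < 2:
            return False
        # The transpose may only swap the last two axes.
        perm = match_ctx.Attr("perm")
        perm_size = len(perm)
        for i in range(perm_size - 2):
            if perm[i] != i:
                return False
        if perm[perm_size - 1] != perm_size - 2 and perm[perm_size - 2] != perm_size - 1:
            return False
        return True

    def result_pass(python_pat):
        # As in the C++ DRR API, the result pattern is obtained from the source pattern.
        python_res = python_pat.ResultPattern()

        def res_transpose_x(match_ctx):
            # The explicit transpose of x is folded into matmul by flipping transpose_x.
            return not match_ctx.Attr("transpose_x")

        def res_transpose_y(match_ctx):
            # y is not transposed by the pattern, so transpose_y keeps its matched value.
            return match_ctx.Attr("transpose_y")

        transpose_x = python_res.ComputeAttr(res_transpose_x)
        transpose_y = python_res.ComputeAttr(res_transpose_y)

        fused_matmul_transpose_op = python_res.Op(
            "pd_op.matmul",
            {"transpose_x": transpose_x, "transpose_y": transpose_y})
        fused_matmul_transpose_op([python_res.Tensor("matmul_op_out")],
                                  [python_res.Tensor("x"), python_res.Tensor("y")])

    python_pat = pattern_pass()
    result_pass(python_pat)

    return python_ctx


matmul_transpose_fuse_ctx = matmul_transpose_fuse_pass()
# pass_register: registration entry point proposed in this PR (not shown here).
pass_register.add_pass("matmul_transpose_fuse_pass", matmul_transpose_fuse_ctx)
```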
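The rewrite can be sanity-checked without Paddle. The sketch below is plain Python, not the proposed DRR API: `constraint` is a hypothetical helper that re-implements the checks in `cons_function` on plain lists, and `matmul` models the `transpose_x`/`transpose_y` flags of `pd_op.matmul` for 2-D inputs only (each flag transposes its operand before the product). It checks that `matmul(transpose(x), y, tx, ty)` equals `matmul(x, y, not tx, ty)`, i.e. only `transpose_x` should be flipped in the result pattern.

```python
from itertools import product


def constraint(x_shape, y_shape, perm):
    """Plain-list re-implementation of cons_function (hypothetical helper)."""
    # Both matmul operands must be at least 2-D; the perm guard is added so
    # the indexing below is safe when used standalone.
    if len(x_shape) < 2 or len(y_shape) < 2 or len(perm) < 2:
        return False
    n = len(perm)
    # All leading axes must stay in place.
    if any(perm[i] != i for i in range(n - 2)):
        return False
    # Same check as cons_function: reject when the last two axes are not swapped.
    if perm[n - 1] != n - 2 and perm[n - 2] != n - 1:
        return False
    return True


def transpose2d(m):
    """Swap the two axes of a 2-D list of lists."""
    return [list(row) for row in zip(*m)]


def matmul(a, b, transpose_x=False, transpose_y=False):
    """2-D model of matmul: each flag transposes its operand first."""
    if transpose_x:
        a = transpose2d(a)
    if transpose_y:
        b = transpose2d(b)
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


X = [[1, 2], [3, 4]]
Y = [[5, 6], [7, 8]]

# Source: transpose(x) feeding matmul. Result: a single matmul with
# transpose_x flipped and transpose_y unchanged.
for tx, ty in product([False, True], repeat=2):
    assert matmul(transpose2d(X), Y, tx, ty) == matmul(X, Y, not tx, ty)

# Flipping transpose_y as well would change the result.
print(matmul(transpose2d(X), Y))                    # [[26, 30], [38, 44]]
print(matmul(X, Y, True, True))                     # [[23, 31], [34, 46]]
print(constraint([2, 3, 4], [2, 4, 5], [0, 2, 1]))  # True
print(constraint([2, 3, 4], [2, 4, 5], [0, 1, 2]))  # False
```

Square matrices are used so every flag combination is shape-valid; the same identity holds for N-D inputs where the flags act on the last two axes.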

FlamingoPg, Oct 17 '24 02:10

Your PR has been submitted. Thanks for your contribution! Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot[bot], Oct 17 '24 02:10

CLA assistant check
All committers have signed the CLA.

CLAassistant, Oct 17 '24 02:10


Sorry to inform you that f843ba3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot], Oct 30 '24 03:10

Sorry to inform you that 3ec0eda's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot], Nov 09 '24 03:11

cc: @yuanlehome

FlamingoPg, Nov 14 '24 16:11

Could you post the documentation link?

yuanlehome, Dec 11 '24 09:12

> Could you post the documentation link?

Sure, here is the RFC: https://github.com/PaddlePaddle/community/pull/1026

FlamingoPg, Dec 11 '24 09:12

TODO:
1. Support registering multiple auto DRR patterns within a single auto DRR pass.
2. Support registration using the device return type from Python.
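TODO item 1 asks for several auto DRR patterns under one pass. One possible shape is a registry that keeps an ordered list of pattern builders per pass name. The sketch below is plain Python and entirely hypothetical: `AutoDrrPassRegistry`, `pattern`, and `build` are not part of Paddle or of this PR, and the builders return placeholder strings instead of real `DrrPatternContext` objects.

```python
from collections import defaultdict
from typing import Callable, Dict, List

# Hypothetical: a builder returns whatever object describes one
# source/result pattern pair (a DrrPatternContext in the proposal).
PatternBuilder = Callable[[], object]


class AutoDrrPassRegistry:
    """Hypothetical registry letting one pass name own several patterns."""

    def __init__(self) -> None:
        self._patterns: Dict[str, List[PatternBuilder]] = defaultdict(list)

    def pattern(self, pass_name: str):
        """Decorator: append a builder to the pass, keeping registration order."""
        def wrap(builder: PatternBuilder) -> PatternBuilder:
            self._patterns[pass_name].append(builder)
            return builder
        return wrap

    def build(self, pass_name: str) -> List[object]:
        """Build every pattern registered under pass_name (empty list if none)."""
        return [builder() for builder in self._patterns.get(pass_name, [])]


registry = AutoDrrPassRegistry()


@registry.pattern("matmul_transpose_fuse_pass")
def matmul_x_transpose_pattern():
    return "transpose(x) -> matmul"  # placeholder for a real pattern context


@registry.pattern("matmul_transpose_fuse_pass")
def matmul_out_transpose_pattern():
    return "matmul -> transpose(out)"  # placeholder for a real pattern context


print(registry.build("matmul_transpose_fuse_pass"))
```

Keeping builders (rather than built contexts) defers pattern construction until the pass is actually applied, and the per-name list preserves the order in which patterns were registered.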

FlamingoPg, Dec 11 '24 09:12