[WIP][DO NOT MERGE] feat: add Python pass for DRR
PR Category
User Experience
PR Types
New features
Description
这是 Python DRR Pass 的初版，以下是一份可能的使用样例：
import paddle
import paddle.pir as pir
def matmul_transpose_fuse_pass():
    """Build a DRR context that folds a ``transpose`` of X into the following ``matmul``.

    Source pattern::

        x_transpose_out = pd_op.transpose(x, perm=perm)
        matmul_op_out = pd_op.matmul(x_transpose_out, y,
                                     transpose_x=tx, transpose_y=ty)

    Result pattern::

        matmul_op_out = pd_op.matmul(x, y, transpose_x=not tx, transpose_y=ty)

    The rewrite only fires when ``perm`` keeps every leading axis in place and
    swaps the last two, i.e. the transpose is a plain matrix transpose of X.

    NOTE(review): this is a usage sketch for a proposed Python DRR binding.
    ``pir.DrrPatternContext``, ``pir.GetShapeFromValue`` and the pattern
    methods used below are not defined here; confirm the final names and
    signatures against the binding.

    Returns:
        The ``pir.DrrPatternContext`` holding the source pattern, its
        constraint, and the result pattern.
    """
    python_ctx = pir.DrrPatternContext()

    def pattern_pass():
        """Declare the source pattern ``matmul(transpose(x), y)`` and return it."""
        python_pat = python_ctx.SourcePattern()
        # TODO: consider letting users write short local op names and mapping
        # them to the full "pd_op.*" names inside the binding.
        # Attributes are passed as a name -> attribute dict (a C++-style
        # `{{key, value}}` literal would build a set of sets in Python).
        matmul_op = python_pat.Op(
            "pd_op.matmul",
            {
                "transpose_x": python_pat.Attr("transpose_x"),
                "transpose_y": python_pat.Attr("transpose_y"),
            },
        )
        transpose_op = python_pat.Op(
            "pd_op.transpose", {"perm": python_pat.Attr("perm")}
        )

        # Python cannot assign to a call, so the C++ form
        # `Tensor(out) = op(Tensor(in))` is written in the op(output, *inputs)
        # form proposed for this API.
        # NOTE(review): binding signature assumed; confirm.
        transpose_op(python_pat.Tensor("x_transpose_out"), python_pat.Tensor("x"))
        matmul_op(
            python_pat.Tensor("matmul_op_out"),
            python_pat.Tensor("x_transpose_out"),
            python_pat.Tensor("y"),
        )

        def cons_function(match_ctx):
            """Return True only when ``perm`` swaps exactly the last two axes."""
            x_shape = pir.GetShapeFromValue(match_ctx.Tensor("x"))
            y_shape = pir.GetShapeFromValue(match_ctx.Tensor("y"))
            # Both operands need at least two dims for a matrix transpose.
            if len(x_shape) < 2 or len(y_shape) < 2:
                return False
            perm = match_ctx.Attr("perm")
            perm_size = len(perm)
            if perm_size < 2:
                return False
            # Every leading axis must stay where it is ...
            if any(perm[i] != i for i in range(perm_size - 2)):
                return False
            # ... and the last two axes must be swapped.
            return (
                perm[perm_size - 2] == perm_size - 1
                and perm[perm_size - 1] == perm_size - 2
            )

        # The constraint needs matched shapes/attributes, so it is supplied
        # as a Python callback; it must be defined before it is registered.
        python_pat.AddConstraint(cons_function)
        return python_pat

    def result_pass(python_pat):
        """Declare the fused ``matmul(x, y)`` result pattern for ``python_pat``."""
        res = python_pat.ResultPattern()

        def res_transpose_x(match_ctx):
            # x is now fed without the explicit transpose, so matmul must
            # transpose it instead: op_tx(x^T) == op_{not tx}(x).
            return not match_ctx.Attr("transpose_x")

        def res_transpose_y(match_ctx):
            # y is not transposed by the source pattern, so its flag is kept
            # unchanged; negating it would change the product.
            return match_ctx.Attr("transpose_y")

        transpose_x = res.ComputeAttr(res_transpose_x)
        transpose_y = res.ComputeAttr(res_transpose_y)
        fused_matmul_transpose_op = res.Op(
            "pd_op.matmul",
            {"transpose_x": transpose_x, "transpose_y": transpose_y},
        )
        fused_matmul_transpose_op(
            res.Tensor("matmul_op_out"), res.Tensor("x"), res.Tensor("y")
        )

    # The result pattern is derived from the source pattern object.
    result_pass(pattern_pass())
    return python_ctx
# Build the DRR context once, then register it under the pass name.
# NOTE(review): `pass_register` is not imported or defined anywhere in this
# sketch; confirm which module provides it.
matmul_transpose_fuse_ctx = matmul_transpose_fuse_pass()
pass_register.add_pass(
    "matmul_transpose_fuse_pass",
    matmul_transpose_fuse_ctx,
)
你的PR提交成功,感谢你对开源项目的贡献! 请关注后续CI自动化测试结果,详情请参考Paddle-CI手册。 Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI firstly. See Paddle CI Manual for details.
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.
Sorry to inform you that f843ba3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 3ec0eda's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
cc: @yuanlehome
文档链接贴一下~
文档链接贴一下~
好的，我贴一下 RFC：https://github.com/PaddlePaddle/community/pull/1026
TODO：1. 支持在一个 auto DRR pass 内同时注册多个 auto DRR pattern；2. 支持对 Python 设备返回类型进行注册。