Paddle icon indicating copy to clipboard operation
Paddle copied to clipboard

【开源任务】算子切分推导规则开发,支持更多模型使用自动并行,简化更多用户的分布式开发成本

Open Glencsa opened this issue 8 months ago • 10 comments

一、任务背景

1.1 自动并行原理,自动并行与算子切分推导规则的关系

飞桨3.0发布了动静统一的自动并行功能,目的是简化分布式的开发,仅需要用户在神经网络中某些关键位置配置或标注分布式状态,深度学习框架就自动推断出网络中剩余其它部分的分布式状态并执行。自动推断在保证计算结果正确性的前提下,需要实现最优的通信和计算。 具体在框架实现中,当使用自动并行的方式执行1个神经网络时,网络中的输入数据(Tensor)和训练权重都会按照用户的配置和标注,初始化出分布式状态。分布式状态主要包括2部分:设备组织(通过paddle.distributed.ProcessMesh构造) 和 设备对数据(Tensor)的切分状态(通过paddle.distributed.Placement构造)。计算时,会按照神经网络定义执行某些算子的计算,即对于神经网络中的每个算子,只要定义1个规则(下面称 切分推导规则),就可以自动推断出此算子计算过程中需要的通信,进而逐个算子自动推断出整个网络所需要的通信,从而实现最优的通信和计算。 切分推导规则和算子计算逻辑强相关,其实现需要根据算子的计算逻辑、设备对输入数据(Tensor)的切分状态 来推断确定 最优的输入/输出数据(Tensor)的切分状态。如果某个输入/输出的切分状态 和 推断出的最优切分状态 不一样,则框架就能自动推断出在相应位置上所需要的通信。 Image 因此,为了让更多模型能够使用自动并行,简化更多用户的分布式开发成本,我们需要开发每个算子的切分推导规则。

1.2 算子切分推导规则介绍

自动并行中用户只标记了组网中部分 Tensors (Op)的切分状态(DistAttr),模型组网中仅有部分Tensors 有分布式属性(DistAttr)。在自动并行实际执行过前,模型组网中的所有Tensors(Ops) 都需要有一个确定的切分状态,每个Local 设备(进程)需要根据分布式属性信息判断在执行过程中当前设备(进程)需要的通信和切分操作。 算子切分推导规则的目标就是在利用该算子进行计算的同时,根据输入组网中的部分切分状态,推导补全整个组网的切分状态。

理想情况下,每个Op 都会有一个专门的切分推导规则。

自动并行中,每个算子的执行逻辑如下:

Image

数据的切分信息分为3种:

  • shard:在指定的张量维度上对张量进行切分。
  • replicate:跨设备复制tensor,每个rank得到完全相同的tensor
  • partial:一种张量,在不同设备上具有相同的形状,但在每个设备上只有部分值。 它可以进一步规约操作(即sum/min/max)以获得分布式张量。 这通常用作中间表示。

合法切分状态的推导

  • 对于一个孤立的 Tensor,我们可以随意设置它的在集群中的切分状态。 但是对于一个算子其输入输出Tensor 的切分状态不能是任意的。
  • 基于算子自身的运算逻辑,给定一个输入(输出)的切分切状态,其输出(输入) 合法的切分状态是一个有限的集合。 基于用户部分切分标记,如何推导合法的切分状态。
  • 合法的定义为:Tensor 切分状态(shape,partial)满足Op 的运算要求,并能获得正确的(local)计算结果。

Matmul算子切分推导规则举例

import paddle.distributed as dist
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
X W Y = XW
[Replicate, Replicate] [Replicate, Replicate] [Replicate, Replicate]
[Replicate, Replicate] [Replicate, 'x'] [Replicate, 'x']
['y', Replicate] [Replicate, Replicate] ['y', Replicate]
[Replicate, 'x'] ['y', Replicate] [Replicate, Replicate]
['y', Replicate] [Replicate, 'x'] ['y', 'x']
['y', 'x'] ['x', Replicate] ['y', Replicate]

不同算子的分布式规则只和算子逻辑相关。

1.3 算子切分推导规则开发原则

具体开发规则可以参考切分推导规则参考文档

半自动框架为步骤中的公共逻辑提供 公共 utils 函数,在Paddle/paddle/phi/infermeta/spmd_rules目录下,开发者只需要实现与 Op 自身运算法则相关的逻辑。 Image

开发者需要在该目录下创建算子同名的cpp文件(如:argmax.cc, argmax.h),在文件当中为该算子开发切分推导规则(若该算子有对应的反向算子,则要求在该文件中连同该反向算子的切分推导规则一同开发),开发完成以后,需在该文件夹中的rule.cc文件当中完成算子切分推导规则注册,即可完成对该算子的切分推导规则开发。此外,开发者需要在 Paddle/test/cpp/auto_parallel 文件夹下增加该算子切分推导规则的单元测试,完成新增代码的单测覆盖,并测试该算子切分推导规则的正确性。

二、任务详情,需要开发的算子列表

本期需要增加切分推导规则的算子如下,整体进展:

序号 算子名称 队伍名称/状态/PR 难度
1 topk @ooooo-create #72499
0.5×⭐️
2 cummax @ooooo-create #72720
0.5×⭐️
3 cummin @ooooo-create #72720
0.5×⭐️
4 batch_norm @Glencsa #72918
0.5×⭐️
5 mean_all @ooooo-create #72479
0.5×⭐️
6 unique @ooooo-create #72824
0.5×⭐️
7 expand_as @ooooo-create #72845 #73107
0.5×⭐️
8 log_softmax @ooooo-create #72720
0.5×⭐️
9 group_norm @Glencsa #72946
0.5×⭐️
10 index_select @ooooo-create #72727
0.5×⭐️
11 instance_norm @Glencsa #72938
0.5×⭐️
12 label_smooth @ooooo-create #72845
0.5×⭐️
13 sync_batch_norm @Glencsa #72918
0.5×⭐️
14 roll @ooooo-create #72740
0.5×⭐️
15 index_put @Juggler-YAN #73155
@ttuuuuyyyj
0.5×⭐️
16 depthwise_con2d @NKNaN #73134
17 conv2d_transpose @NKNaN #73188
18 conv3d @NKNaN #72882
19 roi_align @ooooo-create #72925

⭐️ 提交PR 模版 ⭐️:

// ------- PR 标题 --------

[Auto Parallel] Add spmd rule No.xxx for xxx and xxx_grad ops.  

// ------- PR 内容 --------

PR Category
Auto Parallel

PR types
New features

Description
为xxx和xxx_grad算子增加切分推导规则。

三、参考指南

建议

  • 开发者需要先看懂Paddle/paddle/phi/infermeta/spmd_rules目录下的utils.cc基础函数的一些使用,以及已算子命名的cpp文件(argmax.cc, argmax.h)的代码逻辑,有助于帮助开发者对新算子切分推导规则开发的快速入门。
  • 开发者在写单测代码时,应该考虑算子切分推导尽可能多的切分情况。

题目讲解见录屏文件:https://meeting.tencent.com/crm/l59EWmRZc4 (00:00:00~00:15:30)

看板信息

任务方向 任务数量 提交作品 / 任务认领 提交率 完成 完成率
算子切分推导规则开发 19 19 / 19 100.0% 13 68.42%

统计信息

排名不分先后 @ooooo-create (11) @Glencsa (2)

Glencsa avatar Apr 22 '25 12:04 Glencsa

【报名】:1、5

ooooo-create avatar Apr 25 '25 06:04 ooooo-create

【报名】:2、3、6

ooooo-create avatar May 14 '25 04:05 ooooo-create

【报名】:15

ghost avatar May 21 '25 14:05 ghost

【报名】:4、9

Glencsa avatar May 22 '25 02:05 Glencsa

【报名】:19

ooooo-create avatar May 25 '25 08:05 ooooo-create

【报名】:16,17

NKNaN avatar May 25 '25 09:05 NKNaN

【报名】:11

Glencsa avatar May 25 '25 16:05 Glencsa

【报名】:13

Glencsa avatar May 27 '25 08:05 Glencsa

【报名】:13

No.13(sync_batch_norm) and No.4(batch_norm) are finished in the same PR(PR#72918)

Glencsa avatar Jun 04 '25 02:06 Glencsa

【报名】:15

ttuuuuyyyj avatar Jun 05 '25 11:06 ttuuuuyyyj

【报名】:15

Glencsa avatar Jun 18 '25 06:06 Glencsa

由于pr #73233 即将合入, 此PR合入后导致算子切分推导规则的开发发生一些变化,影响到开发中或已经完成但未合入的pr,官网的正式文档在更新中,这里提前告知一下。

pr #73233 主要升级点及原因: 将dims_mapping 从[vector] (std::vector<int64_t>) 升级到了[vector of vector] (std::vector<std::vector<int64_t>>) 数据类型,index同样表示tensor dim, value std::vector<int64_t>表示0个到多个mesh_dim。支持了多个mesh dim切分同一个tensor dim。value的std::vector<int64_t> size为1时,和原有的情况兼容;为0时,表示Replicate切分状态。

带给spmd切分推导规则的影响: 1.对于不支持多个mesh dim 切分同一 tensor dim的op,可以按照std::vector<int64_t>开发spmd规则,仍通过TensorDistAttr的dims_mapping()成员函数获得dims_mapping。也可以按照std::vector<std::vector<int64_t>>开发spmd规则,通过TensorDistAttr的multi_dims_mapping()成员函数获得升级后的dims_mapping。两种写法仅仅在数据结构上不同,不影响spmd本身的逻辑,也不会导致效果上的差别。

2.对于需要支持多mesh dim切分同一个tensor dim的情况,只能通过multi_dims_mapping获得dims_mapping定义。原有通过TensorDistAttr的dims_mapping()成员函数获得dims_mapping(std::vector<int64_t>类型)只能支持一个tensor dim切分一个tensor dim。无法表示多mesh dim切分同一个tensor dim的情况。

3.TensorDistAttr 重载了set_dims_mapping方法,void set_dims_mapping(const std::vector<int64_t>& dims_mapping);void set_dims_mapping(const std::vector<std::vector<int64_t>>& dims_mapping);。原有的dist_attr.set_dims_mapping({-1});的写法会编译报错,需要显式指明{-1}的类型。可以按照 dist_attr.set_dims_mapping(std::vector<int64_t>{-1});修改。

4.对于开发了多mesh dim切分同一tensor dim的op,如reshape,可单独写该op的单测检查正确性。若下游op不支持,将报错,报错提示为"There are %d mesh dim sharded on tensor dim %d, you should call \"multi_dims_mapping()\"

如有其他未能覆盖问题,可以直接回复或者联系[email protected]

liufengwei0103 avatar Jun 20 '25 03:06 liufengwei0103

【开源任务】算子切分推导规则开发,支持更多模型使用自动并行,简化更多用户的分布式开发成本 已全部完成,感谢参与的小伙伴们!

排名不分先后 @ooooo-create (11) @Glencsa (5) @NKNaN (3)

欢迎继续参与 https://github.com/orgs/PaddlePaddle/projects/7 !

luotao1 avatar Jul 09 '25 13:07 luotao1