libai
libai copied to clipboard
[Bug][MT5] Throughput is unexpected
32m[10/19 15:36:23 lb.utils.events] [0 m eta: 7986 days, 7:30:37 iteration: 99/621340880 consumed_samples: 200 total_loss: 9.545 time: 1.1185 s/iter data_time: 0.0151 s/iter total_throughput: 1.79 samples/s lr: 1.02e-08
t5 单机4卡测试
-
机器:oneflow-25 单机4卡
-
oneflow master https://github.com/Oneflow-Inc/oneflow/commit/93d19f3be52632cccc875c8e46011eced14249a0
-
libai main https://github.com/Oneflow-Inc/libai/commit/e9ca4087cb35b3ad268534ee60456db689e36063
-
用例:
t5_nl12_nah12_hs768_FP16_actrue_mp2_pp1_mb32_gb512_1n4g
zero_stage=2
t5 2机4卡测试
-
机器:oneflow-25 oneflow-28 2机一共8卡
-
oneflow master https://github.com/Oneflow-Inc/oneflow/commit/93d19f3be52632cccc875c8e46011eced14249a0
-
libai main https://github.com/Oneflow-Inc/libai/commit/e9ca4087cb35b3ad268534ee60456db689e36063
-
用例:
t5_nl12_nah12_hs768_FP16_actrue_mp2_pp1_mb16_gb512_2n4g
zero_stage=2
这里比较明显的问题是,我们 4 卡 2-D 并行是超过 Megatron 的,但是 两机8卡的吞吐比 单机四卡的还慢。而 Megatron 是一个线性的加速比。
这里有点问题,libai.models.T5Model是megatron的版本,IDEA需要的是huggingface版本的T5,也就是libai的projects下的T5(projects/T5是交付项目),这两个模型结构有区别,已经让yongning增加一份projects/T5的测试了,交付之前也是用projects/T5来和libai.model.T5Model来测的纯数据并行:here,两个模型不一样,感觉不能简单地去比较和megatron的性能,因为megatron实现的不是huggingface版本的T5
两个T5的区别总结:
-
layernorm对应的算子不同(mt5用c++拼接算子:RMSLayernorm)
-
decoder多一层embedding:https://github.com/Oneflow-Inc/libai/blob/b3c5ba2b90ae6debbebf8e9b96806327fb21c9c5/projects/T5/models/attention.py#L117-L120
-
dropout对应算子不同 (mt5使用的是:https://github.com/Oneflow-Inc/oneflow/pull/8693)
-
mt5(projects下的T5)的lm_head没有共享embedding的参数 (https://github.com/Oneflow-Inc/libai/blob/9a4af263756ff6a1c8abe73e9a51a29f0d8c0533/projects/T5/models/t5_model.py#L129-L134 )
-
mt5(projects下的T5)比t5(libai.models中的T5)少了position_embedding,但是mt5中的attention多出了position_bias的相关计算(https://github.com/Oneflow-Inc/libai/blob/e9ca4087cb35b3ad268534ee60456db689e36063/projects/T5/models/attention.py#L272 和 https://github.com/Oneflow-Inc/libai/blob/e9ca4087cb35b3ad268534ee60456db689e36063/projects/T5/models/attention.py#L320 )
-
mt5(projects下的T5)不包含任何bias. (Linear 和 LayerNorm)
-
mt5(projects下的T5)因为要对齐huggingface的版本,没有用到t5(libai.models中的T5)当中的一些优化的地方,比如scale_mask_softmax_fusion,(mt5: https://github.com/Oneflow-Inc/libai/blob/9a4af263756ff6a1c8abe73e9a51a29f0d8c0533/projects/T5/models/attention.py#L232-L244 t5: https://github.com/Oneflow-Inc/libai/blob/9a4af263756ff6a1c8abe73e9a51a29f0d8c0533/libai/layers/attention.py#L214-L250 )
-
mt5(projects下的T5)的MLP层比t5是多出一层Linear的(https://github.com/Oneflow-Inc/libai/blob/main/projects/T5/models/mlp.py )
-
[form chengcheng] Attention 里的 FuseMultiHeadAttention 这个优化, megatron 将原本正常语义下的 batch size 转置到了 第一维 , 这个从语义上是难以理解的,但是从性能上,可以只在 Transformer layer 之前做一次 transpose,内部的 matmul 可以使用 batch gemm 执行。 如果不这么做的话,需要在 每个 layer 内部都做 transpose ,单单这个优化就有 10% 的性能差距。 https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/model/transformer.py#L395
mt5里没用到scale_mask_softmax_fusion,所以是走了t5的att中的else分支
def att_mt5(attention_scores, attention_mask):
dropout = nn.Dropout(0)
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
attention_weights = flow.softmax(attention_scores, dim=-1)
attention_weights = dropout(attention_weights)
return attention_weights
def att_t5(attention_scores, attention_mask, scale_mask_softmax_fusion=True, coeff=None, attention_dropout_prob=0, use_cache=False):
dropout = nn.Dropout(0)
if scale_mask_softmax_fusion:
if attn_mask_type == AttnMaskType.padding:
attention_mask = (
attention_mask.expand_as(attention_scores) if use_cache else attention_mask
)
attention_weights = flow._C.fused_scale_mask_softmax_dropout(
attention_scores,
attention_mask,
fill_value=-10000.0,
scale=coeff,
p=attention_dropout_prob,
)[0]
else:
if coeff is not None:
attention_scores *= coeff
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
attention_weights = flow.softmax(attention_scores, dim=-1)
attention_weights = dropout(attention_weights)
return attention_weights
@chengtbf @strint @xyn1201
@xiezipeng-ML 这里说的hugging face版本的T5指的是 transformers 库的吗?如果是的话,直接支持transformers里面T5的oneflow后端之后,你觉得可以直接跑分布式训练吗?我上周移植了transformers的CLIP的infer,不知道训练会多多少东西。transformers的CLIP和t5应该会共用一些基础的模块吧。
@xiezipeng-ML 这里说的hugging face版本的T5指的是 transformers 库的吗?如果是的话,直接支持transformers里面T5的oneflow后端之后,你觉得可以直接跑分布式训练吗?我上周移植了transformers的CLIP的infer,不知道训练会多多少东西。transformers的CLIP和t5应该会共用一些基础的模块吧。
是的 transformers仓库,我slack请教你
@xyn1201 这个的 nsys 结果是不是还没有
- 刚刚分别跑了dp4_mp2_pp1和dp2_mp4_pp1的2机4卡测试
- dp4_mp2_pp1:吞吐是比较正常的
- dp2_mp4_pp1:这个是IDEA给的配置,跑的很慢,15分钟第一个iter都没有跑完,后面就没再等了。
然后列一下dp4_mp2_pp1这组配置的对比结果,libai的是今天新跑的,megatron用的前面comment里的数据,两个模型的参数对齐了,但是数据集用的不一样,这个麻烦 @xiezipeng-ML 给说明一下
projects/T5 单机4卡测试
-
机器:oneflow-28 单机4卡
-
oneflow master https://github.com/Oneflow-Inc/oneflow/commit/f97f09f1d9a8668c972a12f66d77aaa19b164635
-
libai test_t5_time https://github.com/Oneflow-Inc/libai/commit/0002b6637c92e19728cd26830494fa33ab68efc1
-
对比:
projects/T5 2机4卡测试
- 机器:oneflow-25 oneflow-28 2机一共8卡
- oneflow master https://github.com/Oneflow-Inc/oneflow/commit/f97f09f1d9a8668c972a12f66d77aaa19b164635
- libai test_t5_time https://github.com/Oneflow-Inc/libai/commit/0002b6637c92e19728cd26830494fa33ab68efc1
- 用例:
缺少了 Megatron 1n4d 2n4d 的 nsys,oneflow 1n4d nsys
然后列一下dp4_mp2_pp1这组配置的对比结果,libai的是今天新跑的,megatron用的前面comment里的数据,两个模型的参数对齐了,但是数据集用的不一样,这个麻烦 @xiezipeng-ML 给说明一下
昨晚在libai的main分支里把IDEA的dataset换成了megatron的dataset测了下,两个datasets吞吐是一样的
单卡 mb4_gb32 libai_nsys megatron_nsys @chengtbf
初步分析结论
之前两天的测试和本地测试受到 T5 (Megatron)和 MT5 (huggingface) 的区别,以及本地历史 LiBai 版本的影响,拖延了问题分析的进度。
目前的初步结论是 **2-D SBP 下, OneFlow T5 的 sbp infer 结果是不高效的,比 Megatron 多了几倍的通信开销,导致整体的吞吐慢了三倍。这个现象是随着 batch size 的增大而变得更差 **
两机分析
Megatron 2机 nsys 结果:
主要看两个指标, 单个 iter 的前后向总耗时,以及 kernel 占比:
- 单个 iter 的耗时是 357ms, 分为 fw encoder (78ms)+ fw decoder (24ms) + loss (9ms) + bw decoder (88ms) + bw encoder (150ms)
- 其中,nccl 通信基本上都是 allreduce 通信,占总的计算时长比为: 66% (49.8% + 12.8% + 3.4%)
LiBai MT5 2机 nsys 结果:
- 总耗时 1000ms,是 Megatron 的3倍,其中: fw encoder (360ms)+ fw decoder (94ms) + bw decoder (175ms)+ bw encoder (339ms)
- nccl 占比: 占大头的不是 allreduce,而是 send recv(应该是 sbp 推导到了 bad case,send recv 很不高效) send recv 46.4% (应该全部都是多于的) + allreduce 17.9 % + allgather 12.9 %
单机4卡分析
单机四卡的性能结果, oneflow 比 Megatron 快一些 : oneflow 518ms vs Megatron 716ms
同时 Megatron 的 iter 间调度的间隔很大。 还有不少的优化空间。 oneflow 的调度是比较完美的。
Megatron 1n4d
OneFlow 1n4d
OneFlow 的时间虽然比 Megatron 快,但是并不是最优的,仍有至少 12% 的冗余 send recv 通信,主要是在前向 fw encoder 部分包含大量的 send recv 通信。
单机单卡比较
OneFlow 比 Megatron 优势非常明显: OneFlow 88ms vs Megatron 127ms
结论
- OneFlow 单卡速度领先 Megatron
- 在 4 卡的 shape 下,2d sbp 的推导结果不是那么差,速度领先 Megatron
- 在 8 卡的 shape 下,2d sbp 的推导结果非常差,多了几倍的通信开销,速度比 Megatron 慢 三倍
SBP_INFER_RULE_TAG=2 和 自动并行 测试吞吐
- 机器:oneflow-25 oneflow-28 2机一共8卡
- oneflow master https://github.com/Oneflow-Inc/oneflow/commit/f97f09f1d9a8668c972a12f66d77aaa19b164635
- libai main https://github.com/Oneflow-Inc/libai/commit/e9ca4087cb35b3ad268534ee60456db689e36063
- libai吞吐数据
mt5_pretrain.py
mb16_gb512
dp4_mp2_pp1
zero_stage=2
https://github.com/Oneflow-Inc/oneflow/pull/9288 在第一档允许了自动并行与ZeRO共存,但是实际效果没有测试过。我在16上跑,OOM了,毕竟自动并行还没有考虑内存,有些慌。 不过我看了一下,大的weight的sbp基本都是 (S0, S1),并没有给出(B, B),所以功能上是符合预期了,就是有ZeRO,然后也有AutoParallel。
另外它的op很多,不算variable快5000个了,所以初始化cost的时候很慢,需要20分钟,我优化了一下,估计能压缩一半。 至于搜索算法就很快,半分钟就能出结果。
在测试自动并行的同时,建议先看下半自动推导下mt5的boxing里面,哪里的sbp不符合预期,然后加上一些to_global来控制一下。
带 nccl logical op 和 sbp 的 op graph log:https://oneflow-test.oss-cn-beijing.aliyuncs.com/mt5_test/2n4g_log/output.log
搜索下Operator
可以找到 op graph 的起点。
自动并行 2n4g 测试
- 机器:oneflow-25 oneflow-28 2机一共8卡
- oneflow feat-auto_parallel-ZeRO分支 https://github.com/Oneflow-Inc/oneflow/pull/9288/commits/54771bc917aa1b7509e758b7d5c1344ce00e7246 用这个分支 编译+自动并行的时间是半小时,确实加快了
- libai main https://github.com/Oneflow-Inc/libai/commit/e9ca4087cb35b3ad268534ee60456db689e36063
- 为了不OOM,调小了batch_size,做了一组对比
mb4_gb128
dp4_mp2_pp1
zero_stage=2
@Yipeng1994
看了一下 job/plan 发现了3个问题:
非预期的 SBP 变化
会导致后续一系列 SBP 都乱掉,从而导致有多余的 nccl logical boxing (是不是还有其他影响有多余的 nccl logical boxing 还要测试)
原因是 model.t5_model.encoder.layers.0.self_attention-reshape-29
这个 op 的位置,代码位置在这里,它消费了 query_key_value (broadcast_matmul) 的输出,该 broadcast_matmul 输出的 shape=(N,S,H), sbp=(S(0), S(2)),reshape 将其 reshape 成 (N,S,n,h),预期 sbp 不变仍然是 (S(0), S(2))。
但在2n4d情况下,该 reshape 前面被插入1个 System-NCCL-Logical-(*S)2(*S)-1867
,为 (S(0), S(2)) -> (S(0), S(1)) 的转换,导致后面一系列的 op 都不按照预期的 sbp 来推导,最终导致冗余 nccl logical boxing。而 1n4d 情况该处则正常。
1n4d job 片段
op {
name: "model.t5_model.encoder.layers.0.self_attention-reshape-29"
device_tag: "cuda"
ctrl_in_op_name: "model.t5_model.decoder.layers.0.self_attention-where-922"
scope_symbol_id: 728
stream_name_hint: "NCCL_COMPUTE_0"
loc: "Python Stack[-2]: \'forward\' at \'/home/xuyongning/zero_test/t5_test/libai/projects/T5/models/transformer_layer.py\': line 177; Python Stack[-1]: \'forward\' at \'/home/xuyongning/zero_test/t5_test/libai/projects/T5/models/attention.py\': line 194; ... 9 more"
user_conf {
op_type_name: "reshape"
input {
key: "in"
value {
s: "model.t5_model.encoder.layers.0.self_attention.query_key_value-broadcast_matmul-28/out_0"
}
}
output {
key: "out"
value {
s: "model.t5_model.encoder.layers.0.self_attention-reshape-29/out_0"
}
}
attr {
key: "shape"
value {
at_shape {
dim: 32
dim: 512
dim: 12
dim: 192
}
}
}
input_order: "in"
output_order: "out"
}
}
op_name2nd_sbp_signature_conf {
key: "model.t5_model.encoder.layers.0.self_attention-reshape-29"
value {
bn_in_op2nd_sbp {
key: "in_0"
value {
sbp_parallel {
split_parallel {
axis: 0
}
}
sbp_parallel {
split_parallel {
axis: 2
}
}
}
}
bn_in_op2nd_sbp {
key: "out_0"
value {
sbp_parallel {
split_parallel {
axis: 0
}
}
sbp_parallel {
split_parallel {
axis: 2
}
}
}
}
}
}
2n4d job 片段
op {
name: "model.t5_model.encoder.layers.0.self_attention-reshape-29"
device_tag: "cuda"
ctrl_in_op_name: "model.t5_model.decoder.layers.0.self_attention-reshape-903"
scope_symbol_id: 728
stream_name_hint: "NCCL_COMPUTE_0"
loc: "Python Stack[-2]: \'forward\' at \'/home/xuyongning/zero_test/t5_test/libai/projects/T5/models/transformer_layer.py\': line 177; Python Stack[-1]: \'forward\' at \'/home/xuyongning/zero_test/t5_test/libai/projects/T5/models/attention.py\': line 194; ... 9 more"
user_conf {
op_type_name: "reshape"
input {
key: "in"
value {
s: "System-NCCL-Logical-(*S)2(*S)-1867/out_0"
}
}
output {
key: "out"
value {
s: "model.t5_model.encoder.layers.0.self_attention-reshape-29/out_0"
}
}
attr {
key: "shape"
value {
at_shape {
dim: 64
dim: 512
dim: 12
dim: 192
}
}
}
input_order: "in"
output_order: "out"
}
}
op {
name: "System-NCCL-Logical-(*S)2(*S)-1867"
ctrl_in_op_name: "System-NCCL-Logical-(*S)2(*S)-1866"
ctrl_in_op_name: "model.t5_model.decoder.layers.0.self_attention-reshape-903"
scope_symbol_id: 17628
stream_name_hint: "NCCL_COMPUTE_0"
user_conf {
op_type_name: "_nccl_logical_2D_same_dim0_all2all"
input {
key: "in"
value {
s: "model.t5_model.encoder.layers.0.self_attention.query_key_value-broadcast_matmul-28/out_0"
}
}
output {
key: "out"
value {
s: "System-NCCL-Logical-(*S)2(*S)-1867/out_0"
}
}
attr {
key: "dst_reduced_nd_sbp"
value {
at_list_string {
val: "S(0)"
val: "S(1)"
}
}
}
attr {
key: "src_reduced_nd_sbp"
value {
at_list_string {
val: "S(0)"
val: "S(2)"
}
}
}
input_order: "in"
output_order: "out"
}
}
op_name2nd_sbp_signature_conf {
key: "model.t5_model.encoder.layers.0.self_attention-reshape-29"
value {
bn_in_op2nd_sbp {
key: "in_0"
value {
sbp_parallel {
split_parallel {
axis: 0
}
}
sbp_parallel {
split_parallel {
axis: 1
}
}
}
}
bn_in_op2nd_sbp {
key: "out_0"
value {
sbp_parallel {
split_parallel {
axis: 0
}
}
sbp_parallel {
split_parallel {
axis: 1
}
}
}
}
}
}
1n4d 和 2n4d 他们两个的区别就是 batch size 变化了 (global),猜测原因和 https://github.com/Oneflow-Inc/OneTeam/issues/1721 里面类似。但为什么 SBP_INFER_RULE_TAG=2 设置了后仍然不能阻止该非期望的 sbp 转换?需要再调试一下。
低效的 amp 转换
代码这里的 broadcast_add 的左边 position_bias
由 compute_bias 计算而来 dtype=float16,右边 attention_mask
由外面传入,原本 dtype=bool,因需经过若干 scalar 计算转为 int64,该处 float16 + int64,然后两者都转换成 float32 计算,后面继续进行 matmul 时又转回 float16。
plan 片段
order : 1667 , actor id : 2199025352758 name : model.t5_model.encoder.layers.0.self_attention-broadcast_add-70 thrd : 1048577 device_type : kCUDA stream_index : 1 {
consume : in_ctrl : <- [ System-ZeRO-ParallelCast-model.t5_model.encoder.layers.2.mlp.wo.weight-repeat-248-242/out_ctrl_640 ] ( actor_id: 2199025356993, regst: regst_num: 1, cuda , ctrl )
consume : in : <- [ model.t5_model.encoder.layers.0.self_attention-cast-69/__out_0 ] ( actor_id: 2199025352757, regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,1,1,512) , dtype: kFloat )
consume : in : <- [ model.t5_model.encoder.layers.0.self_attention-expand_dims-64-out_0-cast_h2f/__out_0 ] ( actor_id: 2199025356805, regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (1,12,512,512) , dtype: kFloat )
produce : __z_0 regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,12,512,512) , dtype: kFloat {
-> [ model.t5_model.encoder.layers.0.self_attention-broadcast_add-70-z_0-cast_f2h ] ( actor_id: 2199025356729 )
}
produce : out_ctrl_4504 regst: regst_num: 1, cuda , ctrl {
-> [ model.t5_model.decoder.layers.0.self_attention-transpose-938 ] ( actor_id: 2199025353448 )
}
}
order : 1463 , actor id : 2199025352757 name : model.t5_model.encoder.layers.0.self_attention-cast-69 thrd : 1048577 device_type : kCUDA stream_index : 1 {
consume : in_ctrl : <- [ model.t5_model.encoder.layers.3.self_attention-scalar_mul-279/out_ctrl_632 ] ( actor_id: 2199025352924, regst: regst_num: 1, cuda , ctrl )
consume : in : <- [ model.t5_model.encoder.layers.0.self_attention-scalar_mul-68/__out_0 ] ( actor_id: 2199025352756, regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,1,1,512) , dtype: kInt64 )
produce : __out_0 regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,1,1,512) , dtype: kFloat {
-> [ model.t5_model.encoder.layers.0.self_attention-broadcast_add-70 ] ( actor_id: 2199025352758 )
}
produce : out_ctrl_3032 regst: regst_num: 1, cuda , ctrl {
-> [ model.t5_model.encoder.layers.7.self_attention-scalar_mul-547 ] ( actor_id: 2199025353136 )
}
}
order : 1656 , actor id : 2199025356805 name : model.t5_model.encoder.layers.0.self_attention-expand_dims-64-out_0-cast_h2f thrd : 1048577 device_type : kCUDA stream_index : 1 {
consume : in_ctrl : <- [ System-ZeRO-ParallelCast-model.t5_model.encoder.layers.2.mlp.wi_0.weight-repeat-233-241/out_ctrl_20337 ] ( actor_id: 2199025356992, regst: regst_num: 1, cuda , ctrl )
consume : in : <- [ model.t5_model.encoder.layers.0.self_attention-expand_dims-64/__out_0 ] ( actor_id: 2199025352752, regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (1,12,512,512) , dtype: kFloat16 )
produce : __out_0 regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (1,12,512,512) , dtype: kFloat {
-> [ model.t5_model.encoder.layers.0.self_attention-broadcast_add-70 ] ( actor_id: 2199025352758 )
}
produce : out_ctrl_4496 regst: regst_num: 1, cuda , ctrl {
-> [ model.t5_model.decoder.layers.0.self_attention.relative_attention_bias-gather-937 ] ( actor_id: 2199025353447 )
}
}
低效冗余的 cast
上述的 attention_mask
原本是 dtype=bool 的 mask 张量,需要传入到每一层 transformer layer 进行计算,计算在这里:(1 - attention_mask) * -1000
要进行这些计算,系统选择先把 bool cast to int64,该 cast 在每一层 transformer layer 都重复进行。
plan 片段
order : 1247 , actor id : 2199025352711 name : model.t5_model.encoder.layers.0-identity-17 thrd : 1048577 device_type : kCUDA stream_index : 1 {
consume : in_ctrl : <- [ model.t5_model.decoder.layers.0-identity-888/out_ctrl_216 ] ( actor_id: 2199025353405, regst: regst_num: 1, cuda , ctrl )
consume : in : <- [ model.t5_model.extended_attn_mask-expand_dims-9/__out_0 ] ( actor_id: 2199025352706, regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,1,1,512) , dtype: kBool )
produce : __out_0 regst: regst_num: 1, cuda , time_shape: (1,1,8), shape: (16,1,1,512) , dtype: kBool {
-> [ model.t5_model.encoder.layers.11.self_attention-cast-809 ] ( actor_id: 2199025353342 )
-> [ model.t5_model.encoder.layers.4.self_attention-cast-342 ] ( actor_id: 2199025352973 )
-> [ model.t5_model.encoder.layers.3.self_attention-cast-273 ] ( actor_id: 2199025352918 )
-> [ model.t5_model.encoder.layers.3.self_attention-cast-275 ] ( actor_id: 2199025352920 )
-> [ model.t5_model.encoder.layers.11.self_attention-cast-811 ] ( actor_id: 2199025353344 )
-> [ model.t5_model.encoder.layers.2.self_attention-cast-208 ] ( actor_id: 2199025352867 )
-> [ model.t5_model.encoder.layers.2.self_attention-cast-206 ] ( actor_id: 2199025352865 )
-> [ model.t5_model.encoder.layers.0.self_attention-cast-65 ] ( actor_id: 2199025352753 )
-> [ model.t5_model.encoder.layers.6.self_attention-cast-474 ] ( actor_id: 2199025353077 )
-> [ model.t5_model.encoder.layers.6.self_attention-cast-476 ] ( actor_id: 2199025353079 )
-> [ model.t5_model.encoder.layers.9.self_attention-cast-675 ] ( actor_id: 2199025353236 )
-> [ model.t5_model.encoder.layers.0.self_attention-cast-74 ] ( actor_id: 2199025352761 )
-> [ model.t5_model.encoder.layers.8.self_attention-cast-608 ] ( actor_id: 2199025353183 )
-> [ model.t5_model.encoder.layers.4.self_attention-cast-340 ] ( actor_id: 2199025352971 )
-> [ model.t5_model.encoder.layers.1.self_attention-cast-141 ] ( actor_id: 2199025352814 )
-> [ model.t5_model.encoder.layers.5.self_attention-cast-407 ] ( actor_id: 2199025353024 )
-> [ model.t5_model.encoder.layers.1.self_attention-cast-139 ] ( actor_id: 2199025352812 )
-> [ model.t5_model.encoder.layers.5.self_attention-cast-409 ] ( actor_id: 2199025353026 )
-> [ model.t5_model.encoder.layers.0.self_attention-cast-72 ] ( actor_id: 2199025352759 )
-> [ model.t5_model.encoder.layers.7.self_attention-cast-541 ] ( actor_id: 2199025353130 )
-> [ model.t5_model.encoder.layers.7.self_attention-cast-543 ] ( actor_id: 2199025353132 )
-> [ model.t5_model.encoder.layers.8.self_attention-cast-610 ] ( actor_id: 2199025353185 )
-> [ model.t5_model.encoder.layers.9.self_attention-cast-677 ] ( actor_id: 2199025353238 )
-> [ model.t5_model.encoder.layers.10.self_attention-cast-742 ] ( actor_id: 2199025353289 )
-> [ model.t5_model.encoder.layers.10.self_attention-cast-744 ] ( actor_id: 2199025353291 )
}
produce : out_ctrl_208 regst: regst_num: 1, cuda , ctrl {
-> [ model.t5_model.encoder.layers.0-identity-16 ] ( actor_id: 2199025352710 )
}
}
同时 (1 - attention_mask) * -1000
的计算也在每一层重复,应该也是不必要的。比较高级的做法是通过编译技术消除重复计算,但目前应该没这种 pass,如果需要 benchmark 好看,可以修改一下写法,将 attention_mask 的转换和计算都写在外面去,比如手动 cast 成 float16,在进行上述计算。不好的地方在于代码可能与 pytorch 无法对齐(不过我看现在里面已经插入不少人工干预 to_global,所以应该本来就没那么对齐)。
但在2n4d情况下,该 reshape 前面被插入1个 System-NCCL-Logical-(*S)2(*S)-1867,为 (S(0), S(2)) -> (S(0), S(1)) 的转换,导致后面一系列的 op 都不按照预期的 sbp 来推导,最终导致冗余 nccl logical boxing。而 1n4d 情况该处则正常。
reshape只有一个输入的话,哪个sbp规则下都是match的,不可能发生改变吧? 上游是否发生了强制的转换?或者说reshape是否仍然不是源头?
model.t5_model.encoder.layers.0.self_attention-reshape-29
的上一个 op 是 model.t5_model.encoder.layers.0.self_attention.query_key_value-broadcast_matmul-28
,在 1n4d 和 2n4d 下的 sbp signature 是一致的。
emmm,等下我具体看看为什么推导出了不同的sbp
不是 GreedilyFindMinCopyCostNdSbp 这个函数的问题,而是 GetValidNdSbpSignatureList 的问题。
不是 GreedilyFindMinCopyCostNdSbp 这个函数的问题,而是 GetValidNdSbpSignatureList 的问题。
哎,你也在看,这个今天修好了,在让 @xyn1201 测 在底下这个commit,稍后等结果出来会一起解释 https://github.com/Oneflow-Inc/oneflow/pull/9288/commits/caf344f94d3fcf66bbe914f2ef93f4c03b0086b2
我调试看起来不像是这个原因,而是 reshape 的 GetSbp 函数本身有问题。我再 debug 看看具体是什么。
我调试看起来不像是这个原因,而是 reshape 的 GetSbp 函数本身有问题。我再 debug 看看具体是什么。
哎,刚刚测试出修复失败了,我也再debug康康
reshape 的 sbp siganture list 在 1n4d 下正常,而 2n4d 下不正常的原因找到了:
debug log 片段
E20221023 15:37:22.099296 665754 reshape_user_op_util.cpp:179] [GetReshapeUserOpSbpSignatures] model.t5_model.encoder.layers.0.self_attention-reshape-29: (32,512,2304) -> (32,512,12,192), parallel_num=4
0 (origin=0) -> 0 (origin=0)
1 (origin=1) -> 1 (origin=1)
2 (origin=2) -> 2 (origin=2)
E20221023 15:37:22.099376 665754 operator.cpp:519] [GetNdSbpSignatureList] model.t5_model.encoder.layers.0.self_attention-reshape-29, sbp_sig size=5, sbp_sig_list=
(in_0) -> (out_0): [
(S(0)) -> (S(0)),
(S(1)) -> (S(1)),
(S(2)) -> (S(2)),
(P) -> (P),
(B) -> (B),
]
E20221023 15:45:48.437899 680749 reshape_user_op_util.cpp:179] [GetReshapeUserOpSbpSignatures] model.t5_model.encoder.layers.0.self_attention-reshape-29: (64,512,2304) -> (64,512,12,192), parallel_num=8
0 (origin=0) -> 0 (origin=0)
1 (origin=1) -> 1 (origin=1)
E20221023 15:45:48.437943 680749 operator.cpp:519] [GetNdSbpSignatureList] model.t5_model.encoder.layers.0.self_attention-reshape-29, sbp_sig size=4, sbp_sig_list=
(in_0) -> (out_0): [
(S(0)) -> (S(0)),
(S(1)) -> (S(1)),
(P) -> (P),
(B) -> (B),
]
代码在: https://github.com/Oneflow-Inc/oneflow/blob/22eabed6a2432085cd4aa7bf7bf98464d30e9cba/oneflow/user/ops/reshape_user_op_util.cpp#L131-L132) 处判断当前 dimension 是否可以被 split 的时候是用 % parallel_num
来判断的。
- 1n4d 下 reshape (32,512,2304) to (32,512,12,192), parallel_num=4, dim(2) == 12 被认为是可以 split 的
- 2n4d 下 reshape (64,512,2304) to (64,512,12,192), parallel_num=8, dim(2) == 12 认为是不可以 split 的
所以在 2n4d 下我们根据调试的信息可以看到 reshape 的 sbp signature list 里面没有 S(2) -> S(2) 这一项,但其实是可以 split 的,因为 4dp + 2mp,S(2) 要不切4份,要不切2份(取决于 S(2) 是 nd_sbp 的第1维还是第2维),12 % 4 == 0
和 12 % 2 == 0
都成立。
所以出现了 https://github.com/Oneflow-Inc/libai/issues/406#issuecomment-1287831082 里面所说的情况。这里的正确做法,应该根据 device mesh 的某一个维来判断是否能被 split,而不能只看 parallel_num。
但目前有一些困难,因为在 GetSbp 的时候,并不知晓推导的 sbp signature 将会被应用于 device mesh 的哪一维。这里只能添加上全部的 split(num_axes),然后再到后面的 FilterNdSbpSignatureListByLogicalShape
或其他什么地方去 filter。
reshape 的 sbp siganture list 在 1n4d 下正常,而 2n4d 下不正常的原因找到了:
debug log 片段 代码在: https://github.com/Oneflow-Inc/oneflow/blob/22eabed6a2432085cd4aa7bf7bf98464d30e9cba/oneflow/user/ops/reshape_user_op_util.cpp#L131-L132) 处判断当前 dimension 是否可以被 split 的时候是用 % parallel_num 来判断的。
1n4d 下 reshape (32,512,2304) to (32,512,12,192), parallel_num=4, dim(2) == 12 被认为是可以 split 的 2n4d 下 reshape (64,512,2304) to (64,512,12,192), parallel_num=8, dim(2) == 12 认为是不可以 split 的 所以在 2n4d 下我们根据调试的信息可以看到 reshape 的 sbp signature list 里面没有 S(2) -> S(2) 这一项,但其实是可以 split 的,因为 4dp + 2mp,S(2) 要不切4份,要不切2份(取决于 S(2) 是 nd_sbp 的第1维还是第2维),12 % 4 == 0 和 12 % 2 == 0 都成立。
所以出现了 https://github.com/Oneflow-Inc/libai/issues/406#issuecomment-1287831082 里面所说的情况。这里的正确做法,应该根据 device mesh 的某一个维来判断是否能被 split,而不能只看 parallel_num。
但目前有一些困难,因为在 GetSbp 的时候,并不知晓推导的 sbp signature 将会被应用于 device mesh 的哪一维。这里只能添加上全部的 split(num_axes),然后再到后面的 FilterNdSbpSignatureListByLogicalShape 或其他什么地方去 filter。
是的,原因就跟文晓讲的差不多。通过打印log可以看出来
op: `model.t5_model.encoder.layers.0.self_attention-reshape-29` can't find available sbp signature.
candidate nd sbp signature are: (in_0) -> (out_0): [
((S(0), S(0))) -> ((S(0), S(0))),
((S(0), S(1))) -> ((S(0), S(1))),
((S(0), P)) -> ((S(0), P)),
((S(0), B)) -> ((S(0), B)),
((S(1), S(0))) -> ((S(1), S(0))),
((S(1), S(1))) -> ((S(1), S(1))),
((S(1), P)) -> ((S(1), P)),
((S(1), B)) -> ((S(1), B)),
((P, S(0))) -> ((P, S(0))),
((P, S(1))) -> ((P, S(1))),
((P, P)) -> ((P, P)),
((P, B)) -> ((P, B)),
((B, S(0))) -> ((B, S(0))),
((B, S(1))) -> ((B, S(1))),
((B, P)) -> ((B, P)),
((B, B)) -> ((B, B)),
], but inputs sbp are: in_0: (S(0), S(2));
select idx: 1
备选策略里面没有S(2)。原因就是因为12不被8整除。reshape的sbp这部分是我之前重构的。为什么用的是parallel num,是因为reshape的get sbp函数只推导1d的sbp。而且一般这个1d的sbp推导是不涉及shape的,比如矩阵乘或者是加减这些op。在后面还有一个Filter,这个Filter做的才是根据shape筛选sbp。但是reshape本身跟shape又紧密关联,所以这里才必须要有这个filter。
2d sbp是根据1d sbp的直积得出,1d sbp把S(2) filter掉了,后面自然选不到 (S0, S2)。 那怎么修复呢? 添加所有的split是不行的,reshape的split需要划分一个对应组,只有组的头被整除时能被split。 举一个例子: (32,512,2304, 100) to (32,512,12,192, 100) 组头分别对应 32 -> 32, 512 -> 512, 2304 -> 12, 100 -> 100 也就是 S0 -> S0, S1 ->S1, S2 -> S2, S3 ->S4 阔以看到输出的sbp是不能有 S3的,第三维不是组头。
昨天我做了一个修复尝试 https://github.com/Oneflow-Inc/oneflow/pull/9288/commits/caf344f94d3fcf66bbe914f2ef93f4c03b0086b2 就是在挑选1d sbp的时候hierarchy只保留大于1的最低值。 比如 [2, 4] -> [2, 1] 比如 [4, 2, 2] -> [1, 2, 1] 比如 [16, 4, 8] -> [1, 4, 1] 这样在当这个最低值能被其他维度整除的时候,GetSbp才能给出一个完整的1d sbp备选策略。 为什么使用一个最低值而不直接使用1呢?因为怕有的op对于parallel num为1的hierarchy直接给出一个B。
只是测试结果并没有如愿修复bug。 @xyn1201 做了测试 关自动并行 5331 MiB/41.46 samples/s 开自动并行 9915 MiB/59.27 samples/s
点进吞吐可以看到log,S2还是没有出现。原因未知,不过今天稍微修复一下应该就行了。 总而言之,这个问题的根本已经找到了,修复起来比较简单。
debug_reshape_sbp_signature分支
- https://github.com/Oneflow-Inc/oneflow/commit/4b04b25f521ab2d7727235347c057e3aa584350b
- 2n4g mb16_gb512
refactor-GetSbpSignature分支
- https://github.com/Oneflow-Inc/oneflow/pull/9304/commits/195b0ea149c77374737751356b97f6bf2da240ff
- 2n4g mb4_gb128
- 2n4g mb16_gb512
2个分支吞吐都有接近1倍的提升,但还低于megatron
哎,refactor-GetSbpSignature 也测一下 2n4g mb16_gb512 康康是否会有内存暴涨的问题,然后输出一下boxing的log @xyn1201
refactor-GetSbpSignature分支
Producer (S(0), S(2)), placement: hierarchy: (4,2), device: cuda
Shape: (16,512,2304)
idx: 0, sbp: (S(0), S(0)), placement: hierarchy: (4,2), device: cuda
idx: 1, sbp: (S(0), S(2)), placement: hierarchy: (4,2), device: cuda
op: `model.t5_model.encoder.layers.0.self_attention-reshape-29` can't find available sbp signature.
candidate nd sbp signature are: (in_0) -> (out_0): [
((S(0), S(0))) -> ((S(0), S(0))),
((S(0), S(2))) -> ((S(0), S(2))),
((S(0), S(1))) -> ((S(0), S(1))),
((S(0), P)) -> ((S(0), P)),
((S(0), B)) -> ((S(0), B)),
((S(2), S(0))) -> ((S(2), S(0))),
((S(2), S(2))) -> ((S(2), S(2))),
((S(2), S(1))) -> ((S(2), S(1))),
((S(2), P)) -> ((S(2), P)),
((S(2), B)) -> ((S(2), B)),
((S(1), S(0))) -> ((S(1), S(0))),
((S(1), S(2))) -> ((S(1), S(2))),
((S(1), S(1))) -> ((S(1), S(1))),
((S(1), P)) -> ((S(1), P)),
((S(1), B)) -> ((S(1), B)),
((P, S(0))) -> ((P, S(0))),
((P, S(2))) -> ((P, S(2))),
((P, S(1))) -> ((P, S(1))),
((P, P)) -> ((P, P)),
((P, B)) -> ((P, B)),
((B, S(0))) -> ((B, S(0))),
((B, S(2))) -> ((B, S(2))),
((B, S(1))) -> ((B, S(1))),
((B, P)) -> ((B, P)),
((B, B)) -> ((B, B)),
], but inputs sbp are: in_0: (S(0), S(2));
select idx: 1
把S2加回来了,但是吞吐只有70%,还是需要定位一下其他的问题。
操作失误,上面测试的megatron数据是关掉zero的, 所以重测了megatron开zero,并在下方整理现有的对比结果
开zero测试
- oneflow debug_reshape_sbp_signature分支 https://github.com/Oneflow-Inc/oneflow/commit/4b04b25f521ab2d7727235347c057e3aa584350b
-
export SBP_INFER_RULE_TAG=2
- libai main https://github.com/Oneflow-Inc/libai/commit/e9ca4087cb35b3ad268534ee60456db689e36063
- 2n4g mb16_gb512
嗯嗯,这样容易接受多了