libai
libai copied to clipboard
Use fuse multi head att
batch size = 4, acc step = 8, amp, open Checkpointing
| 1n1g | use_fuse_multi_head_att = False | use_fuse_multi_head_att = True |
|---|---|---|
| Throughput | total_throughput: 151.70 samples/s | total_throughput: 155.41 samples/s |
| GPU Memory | 3147MiB | 3129MiB |
在encoder和decoder中的self_att和cross_att中都使用了fuse_multihead_att.
在28号上简单测了一下,带来的提升有限,应该是transpose的使用次数太多,我下个commit准备把if,else直接取消,默认使用fuse_multihead_att来测一下.
@chengtbf @strint @ouyangyu @CPFLAME
- 这里魔改了一下,
self_att和cross_att都使用了fuse_muti_head_att,attention层默认为fuse_multi_head_att,一共只多出3个必须的transpose:encode_embedding的输出进行一次transpose,decoder_embedding的输出进行一次transpose,loss接收的logits进行一次transpose - 如果数据处理的时候直接处理成
[seq_len, batch_size]的shape的话上述3个transpose可以取消 - 用这个pr下面的单测测过了修改后的模型和huggingface对齐:
tests/model_utils/test_mt5_loader_2.py
@chengtbf @CPFLAME @strint @ouyangyu