libai
libai copied to clipboard
Use fuse multi head att
batch size = 4, acc step = 8, amp, open Checkpointing
1n1g | use_fuse_multi_head_att = False | use_fuse_multi_head_att = True |
---|---|---|
Throughput | total_throughput: 151.70 samples/s | total_throughput: 155.41 samples/s |
GPU Memory | 3147MiB | 3129MiB |
在encoder
和decoder
中的self_att
和cross_att
中都使用了fuse_multihead_att
.
在28号上简单测了一下,带来的提升有限,应该是transpose的使用次数太多,我下个commit准备把if,else直接取消,默认使用fuse_multihead_att
来测一下.
@chengtbf @strint @ouyangyu @CPFLAME
- 这里魔改了一下,
self_att
和cross_att
都使用了fuse_muti_head_att
,attention
层默认为fuse_multi_head_att
,一共只多出3个必须的transpose
:encode_embedding
的输出进行一次transpose
,decoder_embedding
的输出进行一次transpose
,loss
接收的logits
进行一次transpose
- 如果数据处理的时候直接处理成
[seq_len, batch_size]
的shape
的话上述3个transpose
可以取消 - 用这个pr下面的单测测过了修改后的模型和huggingface对齐:
tests/model_utils/test_mt5_loader_2.py
@chengtbf @CPFLAME @strint @ouyangyu