ncnn
ncnn copied to clipboard
ARM: MultiHeadAttention fp16s/a bf16s
现在的arm的multiheadattention只有我好几个月前pr的neon fp32 pack4的实现,这次pr把剩下的补齐:
- fp32 pack1
- fp16s pack1/4/8
- fp16sa pack1/4/8
- bf16s pack1/4 & naive
Codecov Report
Merging #4139 (92d6fc5) into master (acbaaa6) will decrease coverage by
0.03%
. The diff coverage is90.59%
.
:exclamation: Current head 92d6fc5 differs from pull request most recent head 5bbf518. Consider uploading reports for the commit 5bbf518 to get more accurate results
@@ Coverage Diff @@
## master #4139 +/- ##
==========================================
- Coverage 94.43% 94.40% -0.04%
==========================================
Files 748 749 +1
Lines 179004 180668 +1664
==========================================
+ Hits 169046 170551 +1505
- Misses 9958 10117 +159
Impacted Files | Coverage Δ | |
---|---|---|
src/layer/arm/multiheadattention_arm_asimdhp.cpp | 84.54% <84.54%> (ø) |
|
src/layer/arm/multiheadattention_arm.cpp | 98.66% <98.62%> (-0.54%) |
:arrow_down: |
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.
这几个fail掉的都是在test的时候精度不够报错的。 在fp16sa下,有几个小问题:
- 在测试输入的blob的尺寸太大的时候,会出现fp16sa的误差超过”1“,例如把输入blob从(32,128)增大到(64,256)就有可能会出错了
- 由于mha里面有softmax激活,所以结果数据会有一种荡来荡去的分布,在大部分的数据点都是满足精度的(这些点的数值往往在两位数的大小),但在一些较小的激活点值就会超出精度(就是fail的情况,这些数值只有零点几的大小)
- 在失败的数值点,有一个现象就是期望值和具体计算出的值,有围绕0的趋势,就是一个是正的零点几,另一个是负的零点几
- 就算是fail了,但最终算出来的数据跟期望数还是大差不差的
up有什么看法呢?不知道是我写的有问题,还是mha计算链太长了,全用fp16sa精度遭不住,在0周围精度崩掉了。
写一下 int8 的呗... 这么做和 PR 3940 是冲突的...
写一下 int8 的呗... 这么做和 PR 3940 是冲突的...
3940我看您好像都做完了呀,而且您的int8是naive实现呀,应该不会冲突的吧?
没写 arm 的 int8
没写 arm 的 int8
好嘞,不过我得先去学一下int8是咋整的
很简单的,就是 weight/input 用 int8。 softmax 那个地方量化了会掉点,我试过 int4 softmax.
哪天有空我得再试试 int4 softmax,贼心不死。