ncnn icon indicating copy to clipboard operation
ncnn copied to clipboard

ARM: MultiHeadAttention fp16s/a bf16s

Open EdVince opened this issue 1 year ago • 8 comments

现在的arm的multiheadattention只有我好几个月前pr的neon fp32 pack4的实现,这次pr把剩下的补齐:

  1. fp32 pack1
  2. fp16s pack1/4/8
  3. fp16sa pack1/4/8
  4. bf16s pack1/4 & naive

EdVince avatar Aug 12 '22 09:08 EdVince

Codecov Report

Merging #4139 (92d6fc5) into master (acbaaa6) will decrease coverage by 0.03%. The diff coverage is 90.59%.

:exclamation: Current head 92d6fc5 differs from pull request most recent head 5bbf518. Consider uploading reports for the commit 5bbf518 to get more accurate results

@@            Coverage Diff             @@
##           master    #4139      +/-   ##
==========================================
- Coverage   94.43%   94.40%   -0.04%     
==========================================
  Files         748      749       +1     
  Lines      179004   180668    +1664     
==========================================
+ Hits       169046   170551    +1505     
- Misses       9958    10117     +159     
Impacted Files Coverage Δ
src/layer/arm/multiheadattention_arm_asimdhp.cpp 84.54% <84.54%> (ø)
src/layer/arm/multiheadattention_arm.cpp 98.66% <98.62%> (-0.54%) :arrow_down:

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

codecov-commenter avatar Aug 12 '22 09:08 codecov-commenter

这几个fail掉的都是在test的时候精度不够报错的。 在fp16sa下,有几个小问题:

  1. 在测试输入的blob的尺寸太大的时候,会出现fp16sa的误差超过”1“,例如把输入blob从(32,128)增大到(64,256)就有可能会出错了
  2. 由于mha里面有softmax激活,所以结果数据会有一种荡来荡去的分布,在大部分的数据点都是满足精度的(这些点的数值往往在两位数的大小),但在一些较小的激活点值就会超出精度(就是fail的情况,这些数值只有零点几的大小)
  3. 在失败的数值点,有一个现象就是期望值和具体计算出的值,有围绕0的趋势,就是一个是正的零点几,另一个是负的零点几
  4. 就算是fail了,但最终算出来的数据跟期望数还是大差不差的

up有什么看法呢?不知道是我写的有问题,还是mha计算链太长了,全用fp16sa精度遭不住,在0周围精度崩掉了。

EdVince avatar Aug 13 '22 09:08 EdVince

写一下 int8 的呗... 这么做和 PR 3940 是冲突的...

tpoisonooo avatar Aug 15 '22 02:08 tpoisonooo

写一下 int8 的呗... 这么做和 PR 3940 是冲突的...

3940我看您好像都做完了呀,而且您的int8是naive实现呀,应该不会冲突的吧?

EdVince avatar Aug 15 '22 02:08 EdVince

没写 arm 的 int8

tpoisonooo avatar Aug 15 '22 02:08 tpoisonooo

没写 arm 的 int8

好嘞,不过我得先去学一下int8是咋整的

EdVince avatar Aug 15 '22 02:08 EdVince

很简单的,就是 weight/input 用 int8。 softmax 那个地方量化了会掉点,我试过 int4 softmax.

tpoisonooo avatar Aug 15 '22 06:08 tpoisonooo

哪天有空我得再试试 int4 softmax,贼心不死。

tpoisonooo avatar Aug 15 '22 07:08 tpoisonooo