ncnn
ncnn copied to clipboard
pnnx关于多头注意力识别问题
据我观测pnnx的代码,对于多头注意力的这个layer识别上有一些缺陷。
pnnx目前只能支持torch自带的多头注意力,这对于相关落地似乎不是很方便(例如transformer.py中bert的多头注意力)。
主要问题集中在两个方面,一个是缩放点积注意力需要填写scale参数,但是在torch的文档中scale是可省却的(默认值是1 / math.sqrt(query.size(-1)) )。
另一方面是 不同库对于多头注意力 使用的 view-reshape,permute-transpose 这类同义算子没法做到很好的识别。
希望能为 ncnn 提供缩放点积注意力算子。
例如 BertSdpaSelfAttention 的 IR 应该是这样写
15 14
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
Tensor.view op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
Tensor.view op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.view op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None
Tensor.transpose op_10 1 1 19 20 dim0=%dim0 dim1=%dim1
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR
因为pnnx的p的意思就是pytorch(
https://github.com/Tencent/ncnn/pull/6397
https://github.com/Tencent/ncnn/pull/6405