The ConViT GPSA layer is missing the attention normalization, which is inconsistent with the original paper and the torch version
The get_attention function in the GPSA layer implemented in MindCV does not normalize attn:
def get_attention(self, x: Tensor) -> Tensor:
    B, N, C = x.shape
    q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
    q = ops.transpose(q, (0, 2, 1, 3))
    k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
    k = ops.transpose(k, (0, 2, 3, 1))
    pos_score = self.pos_proj(self.rel_indices)
    pos_score = ops.transpose(pos_score, (0, 3, 1, 2))
    pos_score = self.softmax(pos_score)
    patch_score = self.batch_matmul(q, k)
    patch_score = ops.mul(patch_score, self.scale)
    patch_score = self.softmax(patch_score)
    gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
    gating = ops.Sigmoid()(gating)
    attn = (1.0 - gating) * patch_score + gating * pos_score
    attn = self.attn_drop(attn)
    return attn
By contrast, the original torch implementation normalizes attn by its sum:

def get_attention(self, x):
    B, N, C = x.shape
    qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k = qk[0], qk[1]
    pos_score = self.rel_indices.expand(B, -1, -1, -1)
    pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
    patch_score = (q @ k.transpose(-2, -1)) * self.scale
    patch_score = patch_score.softmax(dim=-1)
    pos_score = pos_score.softmax(dim=-1)
    gating = self.gating_param.view(1, -1, 1, 1)
    attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
    attn /= attn.sum(dim=-1).unsqueeze(-1)  # attention normalized by its sum
    attn = self.attn_drop(attn)
    return attn
Although it is unclear whether this normalization has a large or small effect on performance, I think it is best to stay consistent with the original paper.
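For reference, a minimal sketch of how the normalization could be added to the MindSpore get_attention shown above (this assumes the same class attributes as the current MindCV code, e.g. self.pos_proj, self.batch_matmul, self.softmax, self.gating_param, self.attn_drop; only the line marked "added" is new):

def get_attention(self, x: Tensor) -> Tensor:
    B, N, C = x.shape
    q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
    q = ops.transpose(q, (0, 2, 1, 3))
    k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
    k = ops.transpose(k, (0, 2, 3, 1))
    # positional and content attention scores, as in the current implementation
    pos_score = self.pos_proj(self.rel_indices)
    pos_score = ops.transpose(pos_score, (0, 3, 1, 2))
    pos_score = self.softmax(pos_score)
    patch_score = self.batch_matmul(q, k)
    patch_score = ops.mul(patch_score, self.scale)
    patch_score = self.softmax(patch_score)
    # gated combination of positional and content attention
    gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
    gating = ops.Sigmoid()(gating)
    attn = (1.0 - gating) * patch_score + gating * pos_score
    # added: normalize attention by its sum over the last axis,
    # mirroring attn /= attn.sum(dim=-1).unsqueeze(-1) in the torch version
    attn = attn / ops.ReduceSum(keep_dims=True)(attn, -1)
    attn = self.attn_drop(attn)
    return attn

The only behavioral change is the added division, so the rest of the layer would stay as it is.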