External-Attention-pytorch icon indicating copy to clipboard operation
External-Attention-pytorch copied to clipboard

WeightedPermuteMLP代码中的Linear问题?

Open ZVChen opened this issue 2 years ago • 0 comments

WeightedPermuteMLP 中采用了几个全连接层Linear,具体代码位置在ViP.py中的21-23行

        self.mlp_c=nn.Linear(dim,dim,bias=qkv_bias)
        self.mlp_h=nn.Linear(dim,dim,bias=qkv_bias)
        self.mlp_w=nn.Linear(dim,dim,bias=qkv_bias)

这几个线性层的输入输出通道数都是dim,即输入输出的通道数不变 在forward时,除了mlp_c是直接输入了x没有什么问题

    def forward(self,x) :
        B,H,W,C=x.shape

        c_embed=self.mlp_c(x)

        S=C//self.seg_dim
        h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S)
        h_embed=self.mlp_h(h_embed).reshape(B,self.seg_dim,W,H,S).permute(0,3,2,1,4).reshape(B,H,W,C)

        w_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,1,2,4).reshape(B,self.seg_dim,H,W*S)
        w_embed=self.mlp_w(w_embed).reshape(B,self.seg_dim,H,W,S).permute(0,2,3,1,4).reshape(B,H,W,C)

        weight=(c_embed+h_embed+w_embed).permute(0,3,1,2).flatten(2).mean(2)
        weight=self.reweighting(weight).reshape(B,C,3).permute(2,0,1).softmax(0).unsqueeze(2).unsqueeze(2)

        x=c_embed*weight[0]+w_embed*weight[1]+h_embed*weight[2]

        x=self.proj_drop(self.proj(x))

其他的两个线性层在使用时都有问题 可以看到这一步

h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S)

最后将通道数改为了H*S ,在执行时如果H*S不等于C,接下来的线性层就会出错了,实际上这一步肯定会错误。 论文当中的代码处理也是类似的方法,不知道怎么解决?

ZVChen avatar Jun 02 '22 08:06 ZVChen