
Update fusion.py

Open · littleSpongebob opened this issue 3 years ago · 2 comments

The second global attention module (global_att2) is not used.

littleSpongebob · Jul 23 '21

It is used. The code is as follows:

from mxnet.gluon import nn, HybridBlock


class ResGlobLocaforGlobLocaChaFuse(HybridBlock):
    def __init__(self, channels=64, r=4):
        super(ResGlobLocaforGlobLocaChaFuse, self).__init__()
        inter_channels = int(channels // r)

        with self.name_scope():

            # stage-1 local attention: a 1x1-conv bottleneck applied at every spatial position
            self.local_att = nn.HybridSequential(prefix='local_att')
            self.local_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.local_att.add(nn.BatchNorm())
            self.local_att.add(nn.Activation('relu'))
            self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.local_att.add(nn.BatchNorm())

            # stage-1 global attention: global average pooling followed by the same 1x1-conv bottleneck
            self.global_att = nn.HybridSequential(prefix='global_att')
            self.global_att.add(nn.GlobalAvgPool2D())
            self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.global_att.add(nn.BatchNorm())
            self.global_att.add(nn.Activation('relu'))
            self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.global_att.add(nn.BatchNorm())

            # stage-2 local attention, applied to the intermediate fused feature
            self.local_att2 = nn.HybridSequential(prefix='local_att2')
            self.local_att2.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.local_att2.add(nn.BatchNorm())
            self.local_att2.add(nn.Activation('relu'))
            self.local_att2.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.local_att2.add(nn.BatchNorm())

            # stage-2 global attention: the "second global attention module" discussed in this issue
            self.global_att2 = nn.HybridSequential(prefix='global_att2')
            self.global_att2.add(nn.GlobalAvgPool2D())
            self.global_att2.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.global_att2.add(nn.BatchNorm())
            self.global_att2.add(nn.Activation('relu'))
            self.global_att2.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.global_att2.add(nn.BatchNorm())

            self.sig1 = nn.Activation('sigmoid')
            self.sig2 = nn.Activation('sigmoid')

    def hybrid_forward(self, F, x, residual):

        # stage 1: attention weights computed from the element-wise sum of the two inputs
        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = F.broadcast_add(xl, xg)
        wei = self.sig1(xlg)

        # stage-1 fusion: soft selection between x and residual
        xi = F.broadcast_mul(x, wei) + F.broadcast_mul(residual, 1 - wei)

        # stage 2: recompute the weights from the fused feature, using local_att2 and global_att2
        xl2 = self.local_att2(xi)
        xg2 = self.global_att2(xi)
        xlg2 = F.broadcast_add(xl2, xg2)
        wei2 = self.sig2(xlg2)

        # stage-2 fusion: final output
        xo = F.broadcast_mul(x, wei2) + F.broadcast_mul(residual, 1 - wei2)

        return xo
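
For reference, a minimal check (the input shape here is just an example) that runs the block end to end; a single forward pass goes through both stages, so global_att2 and sig2 are exercised:

import mxnet as mx
from mxnet import nd

# build and initialize the fusion block (channels must match the input channel count)
fuse = ResGlobLocaforGlobLocaChaFuse(channels=64, r=4)
fuse.initialize(ctx=mx.cpu())

x = nd.random.uniform(shape=(2, 64, 32, 32))          # identity / main branch
residual = nd.random.uniform(shape=(2, 64, 32, 32))   # residual branch
out = fuse(x, residual)
print(out.shape)  # (2, 64, 32, 32)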

YimianDai · Aug 15 '21

I see. I was looking at fusion.py in the aff_pytorch folder, though.
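
For comparison, here is a compact PyTorch-style sketch of the same two-stage fusion as the MXNet block above. The class and helper names are made up for illustration and this is not the actual contents of aff_pytorch/fusion.py; the point is only that the second stage calls its own global branch (global_att2):

import torch
from torch import nn

def _att_branch(channels, inter_channels, use_gap):
    # 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN, optionally preceded by global average pooling
    layers = [nn.AdaptiveAvgPool2d(1)] if use_gap else []
    layers += [
        nn.Conv2d(channels, inter_channels, kernel_size=1),
        nn.BatchNorm2d(inter_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inter_channels, channels, kernel_size=1),
        nn.BatchNorm2d(channels),
    ]
    return nn.Sequential(*layers)

class TwoStageChannelFuse(nn.Module):
    # illustrative counterpart of ResGlobLocaforGlobLocaChaFuse, not the repo's own class
    def __init__(self, channels=64, r=4):
        super().__init__()
        inter = channels // r
        self.local_att = _att_branch(channels, inter, use_gap=False)
        self.global_att = _att_branch(channels, inter, use_gap=True)
        self.local_att2 = _att_branch(channels, inter, use_gap=False)
        self.global_att2 = _att_branch(channels, inter, use_gap=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        xa = x + residual
        wei = self.sigmoid(self.local_att(xa) + self.global_att(xa))
        xi = x * wei + residual * (1 - wei)
        # second stage: local_att2 and, crucially, global_att2 (not global_att again)
        wei2 = self.sigmoid(self.local_att2(xi) + self.global_att2(xi))
        return x * wei2 + residual * (1 - wei2)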

littleSpongebob · Aug 16 '21