
Split-Attention Module in PyTorch

zhanghang1989 opened this issue 4 years ago · 5 comments

Just in case someone wants to use the Split-Attention module on its own, the module is provided here:

import torch
import torch.nn as nn
from torch.nn import functional as F

class rSoftMax(nn.Module):
    # Softmax over the radix splits within each cardinal group
    # (reduces to a sigmoid gate when radix == 1).
    def __init__(self, radix, cardinality):
        super().__init__()
        assert radix > 0
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # reshape to (B, cardinality, radix, -1), move radix to dim 1,
            # softmax across the radix dimension, then flatten back
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

# Split-Attention block. The input is expected to have `channels` feature maps
# that are already radix-expanded (e.g. the output of a conv with
# out_channels = width * radix); the block returns channels // radix feature maps.
class Splat(nn.Module):
    def __init__(self, channels, radix, cardinality, reduction_factor=4):
        super(Splat, self).__init__()
        self.radix = radix
        self.cardinality = cardinality
        self.channels = channels
        inter_channels = max(channels*radix//reduction_factor, 32)
        # the pooled descriptor has channels // radix maps (the splits are summed first)
        self.fc1 = nn.Conv2d(channels//radix, inter_channels, 1, groups=cardinality)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU(inplace=True)
        # one attention logit per input channel, so the logits split into exactly
        # `radix` groups matching the feature splits in forward()
        self.fc2 = nn.Conv2d(inter_channels, channels, 1, groups=cardinality)
        self.rsoftmax = rSoftMax(radix, cardinality)

    def forward(self, x):
        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # split the radix-expanded input into `radix` groups and sum them
            splited = torch.split(x, rchannel//self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        # global average pooling, then the two 1x1 convs (grouped by cardinality)
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            # weight each split by its attention map and sum over the radix groups
            attens = torch.split(atten, rchannel//self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
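
A minimal usage sketch (the shapes and hyper-parameters below are illustrative only), assuming the input is the radix-expanded feature map produced by a preceding grouped convolution:

import torch

radix, cardinality, width = 2, 1, 64
# feature map with width*radix channels, e.g. the output of a conv with out_channels = width*radix
x = torch.randn(8, width*radix, 56, 56)

splat = Splat(channels=width*radix, radix=radix, cardinality=cardinality)
out = splat(x)
print(out.shape)  # torch.Size([8, 64, 56, 56]) -- width channels after split attention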

zhanghang1989 avatar May 24 '20 23:05 zhanghang1989

@zhanghang1989 thanks for sharing the split-attention module implementation. Can we integrate this with an FPN module instead of the ResNet backbone?

abhigoku10 avatar Jun 16 '20 15:06 abhigoku10

Yes, that should work.
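
One hypothetical sketch of how this could look (SplatFPNBlock is a name invented for illustration, not part of this repo): a per-level FPN output block that replaces the usual 3x3 smoothing conv with a radix-expanding grouped conv followed by the Splat module from the first post:

import torch
import torch.nn as nn

class SplatFPNBlock(nn.Module):
    # Hypothetical per-level FPN head: a grouped 3x3 conv expands to out_channels*radix,
    # then split attention reduces back to out_channels.
    def __init__(self, out_channels=256, radix=2, cardinality=1):
        super().__init__()
        self.conv = nn.Conv2d(out_channels, out_channels*radix, 3, padding=1,
                              groups=cardinality*radix, bias=False)
        self.bn = nn.BatchNorm2d(out_channels*radix)
        self.relu = nn.ReLU(inplace=True)
        self.splat = Splat(out_channels*radix, radix, cardinality)

    def forward(self, x):
        return self.splat(self.relu(self.bn(self.conv(x))))

# e.g. applied to one merged pyramid level of shape (B, 256, H, W)
level = torch.randn(2, 256, 32, 32)
print(SplatFPNBlock()(level).shape)  # torch.Size([2, 256, 32, 32])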

zhanghang1989 avatar Jun 16 '20 16:06 zhanghang1989

@zhanghang1989 are there any example references you can point to?

abhigoku10 avatar Jun 17 '20 10:06 abhigoku10

@Jerryzcn, you have done something similar, right?

zhanghang1989 avatar Jun 17 '20 14:06 zhanghang1989

Hi! I have some questions about the layer fc1 in class Splat. First, what is the meaning of the parameter channels? Second, should the input channels of fc1 be channels//radix, or channels?

Hi-Jingzhi avatar Jun 27 '20 07:06 Hi-Jingzhi