ResNeSt
ResNeSt copied to clipboard
Split-Attention Module in PyTorch
In case anyone wants to use the Split-Attention module on its own, here it is:
import torch
import torch.nn as nn
from torch.nn import functional as F
class rSoftMax(nn.Module):
    """Normalize split-attention logits across the radix axis.

    When ``radix > 1`` each sample's logits are viewed as
    ``(cardinality, radix, -1)``. The cardinality and radix axes are then
    swapped, and a softmax is taken over radix. The result is flattened to
    ``(batch, -1)`` in radix-major order.

    When ``radix`` is not greater than 1, an element-wise sigmoid is applied
    and the input shape is preserved.

    Args:
        radix: Number of splits competing for each channel; must be positive.
        cardinality: Number of cardinal groups the logits are laid out in.
    """

    def __init__(self, radix, cardinality):
        super().__init__()
        assert radix > 0
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        n = x.size(0)
        if not self.radix > 1:
            # A lone split has nothing to compete with, so gate it independently.
            return torch.sigmoid(x)
        grouped = x.view(n, self.cardinality, self.radix, -1)
        # Put radix on dim 1 so the softmax normalizes across the splits.
        weights = F.softmax(grouped.transpose(1, 2), dim=1)
        return weights.reshape(n, -1)
class Splat(nn.Module):
    """Split-attention (Splat) fusion of ``radix`` feature-map splits.

    The input is treated as ``radix`` splits concatenated along dim 1. The
    block works in four steps:

    1. The splits are summed and global-average-pooled.
    2. The pooled features pass through a grouped 1x1 bottleneck
       (fc1 -> BN -> ReLU -> fc2), producing one logit per input channel.
    3. The logits are normalized across the radix axis by ``rSoftMax``
       (a softmax for ``radix > 1``, a sigmoid for ``radix == 1``).
    4. The normalized weights re-weight the splits, which are then summed.

    Args:
        channels: Total channel count of the input ``x``, i.e. all ``radix``
            splits together. Each split, and the output, therefore has
            ``channels // radix`` channels. For ``radix == 1`` the output has
            ``channels`` channels.
        radix: Number of splits in the input.
        cardinality: Number of groups used by fc1/fc2 and by ``rSoftMax``.
        reduction_factor: Bottleneck reduction. The hidden width is
            ``max(channels * radix // reduction_factor, 32)``.

    Raises:
        ValueError: If ``channels`` is not divisible by ``radix``, or if
            ``forward`` receives an input whose channel count is not
            ``channels``.
    """

    def __init__(self, channels, radix, cardinality, reduction_factor=4):
        super(Splat, self).__init__()
        if channels % radix != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by radix ({radix})")
        self.radix = radix
        self.cardinality = cardinality
        self.channels = channels
        inter_channels = max(channels*radix//reduction_factor, 32)
        # fc1 sees the pooled sum of the splits, i.e. channels // radix channels.
        self.fc1 = nn.Conv2d(channels//radix, inter_channels, 1, groups=cardinality)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU(inplace=True)
        # One attention logit per input channel (radix splits x channels//radix).
        # Splitting the attention into chunks of channels // radix must then
        # give exactly one chunk per split. Emitting channels*radix logits here
        # would yield radix**2 chunks, of which zip() in forward would silently
        # keep only the first radix.
        self.fc2 = nn.Conv2d(inter_channels, channels, 1, groups=cardinality)
        self.rsoftmax = rSoftMax(radix, cardinality)

    def forward(self, x):
        """Fuse the splits of ``x`` with learned split attention.

        Args:
            x: 4-D tensor whose dim 1 has ``self.channels`` channels, laid out
                as ``radix`` consecutive splits.

        Returns:
            Tensor with ``channels // radix`` channels on dim 1 (``channels``
            when ``radix == 1``), made contiguous.
        """
        batch, rchannel = x.shape[:2]
        if rchannel != self.channels:
            # A mismatch could otherwise broadcast silently in sum(splited).
            raise ValueError(
                f"expected {self.channels} input channels, got {rchannel}")
        if self.radix > 1:
            splited = torch.split(x, rchannel//self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.bn1(gap)
        gap = self.relu(gap)
        atten = self.fc2(gap)
        # rSoftMax returns radix-major weights, so they can be reshaped to
        # (batch, channels, 1, 1) and broadcast over the spatial dims.
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
        if self.radix > 1:
            attens = torch.split(atten, rchannel//self.radix, dim=1)
            out = sum(att*split for (att, split) in zip(attens, splited))
        else:
            out = atten * x
        return out.contiguous()
@zhanghang1989 Thanks for sharing the split-attention module implementation. Can we integrate it with an FPN module instead of the ResNet backbone?
> @zhanghang1989 Thanks for sharing the split-attention module implementation. Can we integrate it with an FPN module instead of the ResNet backbone?
Yes, that should work.
@zhanghang1989 Are there any example references you can point to?
@Jerryzcn, you have done something similar, right?
Hi! I have two questions about the `channels` parameter and the layer `fc1` in class `Splat`: (1) What does the `channels` parameter mean? (2) Should the input channel count of `fc1` be `channels // radix` or `channels`?