CoordAttention icon indicating copy to clipboard operation
CoordAttention copied to clipboard

how to use it to Remote sensing semantic segmentation

Open wwwmmmqqq opened this issue 3 years ago • 5 comments

how can i use it to Remote sensing semantic segmentation.It cause RuntimeError: The size of tensor a (256) must match the size of tensor b (32) at non-singleton dimension 1. how can i change the result to the four dim to match the base network.

class CoordAtt(nn.Module): def init(self, inp, oup, reduction=32): super(CoordAtt, self).init() self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) self.pool_w = nn.AdaptiveAvgPool2d((1, None))

    mip = max(8, inp // reduction)

    self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(mip)
    self.act = h_swish()
    
    self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    

def forward(self, x):
    identity = x
    
    n,c,h,w = x.size()
    x_h = self.pool_h(x)
    x_w = self.pool_w(x).permute(0, 1, 3, 2)

    y = torch.cat([x_h, x_w], dim=2)
    y = self.conv1(y)
    y = self.bn1(y)
    y = self.act(y) 
    
    x_h, x_w = torch.split(y, [h, w], dim=2)
    x_w = x_w.permute(0, 1, 3, 2)

    a_h = self.conv_h(x_h).sigmoid()
    a_w = self.conv_w(x_w).sigmoid()

    out = identity * a_w * a_h


    return out

wwwmmmqqq avatar Mar 08 '21 09:03 wwwmmmqqq

Given the code you provided, it is hard to figure out the problem.

houqb avatar Mar 09 '21 05:03 houqb

resnet.py class CABottleneck(nn.Module): expansion = 4 def init(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(Bottleneck, self).init() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.ca = CoordAtt(planes * self.expansion,reduction) self.downsample = downsample self.stride = stride

def forward(self, x):
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    out = self.conv3(out)
    out = self.bn3(out)
    out = self.ca(out)
    if self.downsample is not None:
        identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out

coordatt.py class CoordAtt(nn.Module): def init(self, inp, oup, reduction=32): super(CoordAtt, self).init() self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) self.pool_w = nn.AdaptiveAvgPool2d((1, None)) mip = max(8, inp // reduction) self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(mip) self.act = h_swish() self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) def forward(self, x): identity = x n,c,h,w = x.size() x_h = self.pool_h(x) x_w = self.pool_w(x).permute(0, 1, 3, 2)

    y = torch.cat([x_h, x_w], dim=2)
    y = self.conv1(y)
    y = self.bn1(y)
    y = self.act(y) 
    x_h, x_w = torch.split(y, [h, w], dim=2)
    x_w = x_w.permute(0, 1, 3, 2)
    a_h = self.conv_h(x_h).sigmoid()
    a_w = self.conv_w(x_w).sigmoid()
    out = identity * a_w * a_h
    return out

the error result is:RuntimeError: The size of tensor a (256) must match the size of tensor b (32) at non-singleton dimension 1

wwwmmmqqq avatar Mar 09 '21 05:03 wwwmmmqqq

CoordAtt(planes * self.expansion,reduction) shoud be CoordAtt(planes * self.expansion, planes * self.expansion, reduction)

houqb avatar Mar 09 '21 07:03 houqb

thank you very munch.but it happened another problems that CUDA out of memory when running.My batch size=6,it is so low that i do not want to make it lower.did you have some ideas to solve it.thank you very much.

wwwmmmqqq avatar Mar 09 '21 11:03 wwwmmmqqq