MSCG-Net
MSCG-Net copied to clipboard
IJRS 2021 paper code
Do you have a version of the code optimised for the IJRS 2021 paper?
Self-constructing graph neural networks to model long-range pixel dependencies for semantic segmentation of remote sensing images (https://www.tandfonline.com/doi/full/10.1080/01431161.2021.1936267?scroll=top&needAccess=true)
The model for the IJRS 2021 paper can easily be built on top of /lib/net/scg_gcn.py, as shown below. The training pipeline for the Vaihingen dataset is almost the same as for DDCM-Net.
from lib.net.scg_gcn import *
class SCG_Net_R50(nn.Module):
    """SCG-Net segmentation model on a truncated ResNet-50 backbone.

    The backbone (ResNet-50 stem + ``layer1``..``layer3``) produces a feature
    map that is passed to ``SCG_block``. Its outputs (a matrix ``A`` plus node
    features) go through two ``GCN_Layer``s that map the features to
    ``out_channels`` channels. The result is reshaped onto the ``nodes``
    grid and bilinearly upsampled back to the input resolution.

    Args:
        out_channels: Number of output channels (presumably segmentation
            classes). Also used as the SCG hidden size and as
            ``num_cluster`` when reshaping the GCN output.
        pretrained: If True, load ResNet-50 weights from
            ``model_urls['resnet50']`` before building the backbone.
        nodes: ``(rows, cols)`` of the node grid; forwarded to
            ``SCG_block`` as ``node_size``.
        dropout: Dropout rate forwarded to the first GCN layer and to
            ``SCG_block``.
        enhance_diag: Forwarded to ``SCG_block`` as ``add_diag``.
        aux_pred: If True, add ``gamma * z_hat`` (both returned by
            ``SCG_block``) to the GCN output before upsampling.
    """

    def __init__(self, out_channels=6, pretrained=True,
                 nodes=(28, 28), dropout=0,
                 enhance_diag=True, aux_pred=True):
        super().__init__()
        self.aux_pred = aux_pred
        self.node_size = nodes
        self.num_cluster = out_channels

        resnet = models.resnet50()
        if pretrained:
            # NOTE(review): ``load_state_dict_from_url`` and ``model_urls`` are
            # assumed to come from the ``lib.net.scg_gcn`` star import. Recent
            # torchvision releases no longer export ``model_urls`` — confirm
            # against the installed version.
            state_dict = load_state_dict_from_url(model_urls['resnet50'],
                                                  progress=True)
            resnet.load_state_dict(state_dict)

        # Backbone: ResNet-50 stem + layer1..layer3 only; layer4 and the
        # classification head are not used.
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

        # ResNet-50's layer3 emits 1024 channels, hence in_ch / in_features = 1024.
        # Construction order is kept as in the original so parameter
        # registration order and initialization RNG use are unchanged.
        self.graph_layers1 = GCN_Layer(1024, 128, bnorm=True,
                                       activation=nn.ReLU(True),
                                       dropout=dropout)
        self.graph_layers2 = GCN_Layer(128, out_channels, bnorm=False,
                                       activation=None)
        self.scg = SCG_block(in_ch=1024,
                             hidden_ch=out_channels,
                             node_size=nodes,
                             add_diag=enhance_diag,
                             dropout=dropout)
        weight_xavier_init(self.graph_layers1, self.graph_layers2, self.scg)

    def forward(self, x):
        """Compute dense predictions for ``x``.

        Args:
            x: Input image batch; assumed (B, 3, H_in, W_in) — TODO confirm.

        Returns:
            In training mode, a tuple ``(out, loss)``, where ``loss`` is the
            auxiliary loss returned by ``SCG_block``. In eval mode, ``out``
            only. ``out`` has ``out_channels`` channels at the input's
            spatial size.
        """
        input_hw = x.size()[2:]

        feats = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        B, C, H, W = feats.size()

        A, gx, loss, z_hat, gamma = self.scg(feats)
        # NOTE(review): plain reshape (no permute) from SCG_block's output
        # layout to (B, nodes, C); kept as-is to match the original model.
        gx, _, _ = self.graph_layers2(
            self.graph_layers1((gx.reshape(B, -1, C), A, False)))
        if self.aux_pred:
            # Out-of-place on purpose: an in-place ``+=`` on a GCN output can
            # break autograd if that tensor is saved for the backward pass.
            gx = gx + gamma * z_hat

        gx = gx.reshape(B, self.num_cluster, self.node_size[0], self.node_size[1])
        # Two-stage upsampling (node grid -> backbone resolution -> input
        # resolution), preserved from the original design.
        gx = F.interpolate(gx, (H, W), mode='bilinear', align_corners=False)
        out = F.interpolate(gx, input_hw, mode='bilinear', align_corners=False)
        if self.training:
            return out, loss
        return out