CCNet
Applying RCCA to a video task: the loss does not fully converge
Hello, I am trying to apply the PyTorch version of the RCCA module you released across different frames of a video, in order to obtain inter-frame attention and thereby enhance the feature representation of each frame. The main problem is that the loss never fully converges and hovers between 1 and 2. I would like to rule out a mistake in my network modifications, so I need your help!
The task is video saliency detection. Two arbitrary frames from the same video are passed through the same ResNet-101 to obtain B x 256 x 47 x 47 features, which are then fed into the RCCA module. This first yields Q_X, K_X, V_X and Q_Y, K_Y, V_Y, i.e. the two frames projected into the Q, K, V spaces. A correlation matrix between Q_X and K_Y is then applied to V_Y, and a correlation matrix between Q_Y and K_X is applied to V_X. The implementation is below, with almost no changes to your code. Could you take a look, or give me some advice on training or hyperparameter tuning? Thanks!
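Before the actual implementation, here is a dense (non-criss-cross) shorthand of that cross-frame attention, just to pin down the intent; `fx`, `fy`, the `flat` helper, and the single matmul over all H*W positions are my simplifications, not part of the repo:

```python
import torch
import torch.nn.functional as F

# Dense stand-in for the cross-frame attention described above; the real
# module below factorizes this along rows and columns (criss-cross).
B, C, H, W = 2, 256, 47, 47
fx = torch.randn(B, C, H, W)  # frame x features (placeholder for ResNet-101 output)
fy = torch.randn(B, C, H, W)  # frame y features

def flat(t):
    return t.flatten(2).transpose(1, 2)  # B,C,H,W -> B,HW,C

q_x, k_y, v_y = flat(fx), flat(fy), flat(fy)            # stand-ins for the 1x1 conv projections
attn_xy = F.softmax(q_x @ k_y.transpose(1, 2), dim=-1)  # B,HW,HW: Q_X correlated with K_Y
out_y = (attn_xy @ v_y).transpose(1, 2).view(B, C, H, W)  # applied to V_Y
```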
```python
class RCCAModule(nn.Module):
    def __init__(self, in_channels, out_channels=256):
        super(RCCAModule, self).__init__()
        # inter_channels = in_channels // 4
        self.cca = CrissCrossAttention(in_channels)
        self.convbX = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
            nn.BatchNorm2d(in_channels))
        self.convbY = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
            nn.BatchNorm2d(in_channels))
        self.bottleneckX = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
            # nn.Dropout2d(0.1),  # would dropout also be useful here?
        )
        self.bottleneckY = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
            # nn.Dropout2d(0.1),  # would dropout also be useful here?
        )

    def forward(self, x, y, recurrence=2):
        # outputX = self.convaX(x)
        # outputY = self.convaY(y)
        outputX = x
        outputY = y
        for i in range(recurrence):
            outputX, outputY = self.cca(outputX, outputY)
        outputX = self.convbX(outputX)
        outputY = self.convbY(outputY)
        outputX = self.bottleneckX(torch.cat([x, outputX], 1))
        outputY = self.bottleneckY(torch.cat([y, outputY], 1))
        return outputX, outputY
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax


def INF(B, H, W):
    # (B*W) x H x H stack of matrices with -inf on the main diagonal
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module """
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        # The three convs below project to Q, K, V; Q and K halve the channels, V keeps them
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        # gamma1/gamma2 start at zero but are learnable: when the module is inserted,
        # training begins from the ImageNet-pretrained features and the gammas slowly
        # grow to a learned value, which makes the whole training process smoother.
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))  # self.gamma2 = torch.zeros(1).cuda().requires_grad_()

    def forward(self, x, y):
        m_batchsize, _, height, width = x.size()  # B x C x H x W; here B=2, C=256, H=W=47
        proj_query_X = self.query_conv(x)  # channel reduction to half (here 128), B,C,H,W
        proj_query_X_H = proj_query_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)  # BW,H,C
        proj_query_X_W = proj_query_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)  # BH,W,C
        proj_key_X = self.key_conv(x)  # channel reduction, B,C,H,W
        proj_key_X_H = proj_key_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_key_X_W = proj_key_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        proj_value_X = self.value_conv(x)  # no channel reduction
        proj_value_X_H = proj_value_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_value_X_W = proj_value_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        proj_query_Y = self.query_conv(y)  # channel reduction, B,C,H,W
        proj_query_Y_H = proj_query_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)  # BW,H,C
        proj_query_Y_W = proj_query_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)  # BH,W,C
        proj_key_Y = self.key_conv(y)  # channel reduction, B,C,H,W
        proj_key_Y_H = proj_key_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_key_Y_W = proj_key_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        proj_value_Y = self.value_conv(y)  # no channel reduction
        proj_value_Y_H = proj_value_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_value_Y_W = proj_value_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        # A = torch.bmm(proj_query_X_H, proj_key_Y_H)   # debug leftovers: identical to the
        # B = self.INF(m_batchsize, height, width)      # energy_X_H computation below
        # C = A + B
        # In the BW,H,H attention map, each column covers every H position of the query frame; likewise for BH,W,W.
        energy_X_H = (torch.bmm(proj_query_X_H, proj_key_Y_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)  # B,H,W,H
        energy_X_W = torch.bmm(proj_query_X_W, proj_key_Y_W).view(m_batchsize, height, width, width)  # B,H,W,W
        concateX = self.softmax(torch.cat([energy_X_H, energy_X_W], 3))  # B,H,W,H+W
        att_X_H = concateX[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)  # BW,H,H
        att_X_W = concateX[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)  # BH,W,W
        # Same as for X.
        energy_Y_H = (torch.bmm(proj_query_Y_H, proj_key_X_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)  # B,H,W,H
        energy_Y_W = torch.bmm(proj_query_Y_W, proj_key_X_W).view(m_batchsize, height, width, width)  # B,H,W,W
        concateY = self.softmax(torch.cat([energy_Y_H, energy_Y_W], 3))  # B,H,W,H+W
        att_Y_H = concateY[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)  # BW,H,H
        att_Y_W = concateY[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)  # BH,W,W
        # The permute() here acts as a transpose, so each row now covers every H position of the query frame.
        out_Y_H = torch.bmm(proj_value_Y_H, att_X_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)  # B,C,H,W
        out_Y_W = torch.bmm(proj_value_Y_W, att_X_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)  # B,C,H,W
        out_X_H = torch.bmm(proj_value_X_H, att_Y_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)  # B,C,H,W
        out_X_W = torch.bmm(proj_value_X_W, att_Y_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)  # B,C,H,W
        return (self.gamma1 * (out_X_H + out_X_W) + x), (self.gamma2 * (out_Y_H + out_Y_W) + y)
```
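For completeness, a quick shape check of the two modules above (just a sketch; the random tensors stand in for the ResNet-101 features, and CUDA is required because `INF()` calls `.cuda()`):

```python
import torch

# Smoke test under the B x 256 x 47 x 47 setting described above.
module = RCCAModule(in_channels=256, out_channels=256).cuda()
x = torch.randn(2, 256, 47, 47).cuda()  # features of frame x
y = torch.randn(2, 256, 47, 47).cuda()  # features of frame y
outX, outY = module(x, y, recurrence=2)
print(outX.shape, outY.shape)  # both: torch.Size([2, 256, 47, 47])
```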
As for the initialization of this part: the convolution layers use Kaiming initialization with bias 0, and the BN layers have weight set to 1 and bias set to 0.
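In code, that initialization scheme looks roughly like this (a sketch of what was just described; the `init_weights` helper and applying it via `Module.apply` are my own wrapping):

```python
import torch.nn as nn

def init_weights(m):
    # Kaiming init for convolutions with zero bias; BN weight 1, bias 0.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# model.apply(init_weights) walks all submodules, e.g. the RCCAModule above.
```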
@speedinghzl