CCNet icon indicating copy to clipboard operation
CCNet copied to clipboard

将RCCA应用到视频任务中,loss没有完全收敛

Open Lanezzz opened this issue 3 years ago • 1 comments

您好,我打算将您公布的pytorch版本的RCCA模块应用到视频的不同帧之间,以获得帧与帧之间的注意力进而增强视频帧的特征表示。主要问题是loss没有完全收敛,维持在1-2中间。我想排除一下是不是我网络改的有问题,需要您的帮助!!!

主要任务是视频的显著性检测,取同一视频中任意两帧经过同一ResNet-101,获得 B x 256 x 47 x 47的特征,然后再输入到RCCA模块,先得到 Q_X , K_X , V_X , Q_Y, K_Y, V_Y,即得到两帧映射到Q,K,V空间的特征。然后再用 Q_X 和 K_Y 做相关性矩阵,作用到V_Y,然后是Q_Y 和 K_X 做相关性,作用到 V_X。 代码的实现如下,几乎没怎么改动,希望您能帮我看一眼,或者能提供给我一点训练或者调参的建议吗?感谢!!!

`class RCCAModule(nn.Module): def init(self, in_channels, out_channels = 256): super(RCCAModule, self).init()

    #inter_channels = in_channels // 4


    self.cca = CrissCrossAttention(in_channels)

    self.convbX = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
                               nn.BatchNorm2d(in_channels))

    self.convbY = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
                                nn.BatchNorm2d(in_channels))

    self.bottleneckX = nn.Sequential(
        nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
        nn.BatchNorm2d(out_channels),
        #nn.Dropout2d(0.1),  # dropout在这也会有用吗??
        )

    self.bottleneckY = nn.Sequential(
        nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
        nn.BatchNorm2d(out_channels),
        #nn.Dropout2d(0.1),  # dropout在这也会有用吗??
        )



def forward(self, x, y, recurrence=2):
    #outputX = self.convaX(x)
    #outputY = self.convaY(y)
    outputX = x
    outputY = y
    for i in range(recurrence):
        outputX, outputY = self.cca(outputX, outputY)

    outputX = self.convbX(outputX)
    outputY = self.convbY(outputY)

    outputX = self.bottleneckX(torch.cat([x, outputX], 1))
    outputY = self.bottleneckY(torch.cat([y, outputY], 1))

    return outputX, outputY`

` import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Softmax

def INF(B,H,W):

return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W, 1, 1)  # 主对角线为-inf的矩阵

class CrissCrossAttention(nn.Module): """ Criss-Cross Attention Module""" def init(self, in_dim): super(CrissCrossAttention,self).init() # 下面三个是转成Q,K,V之前的降维,V不变 self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//2, kernel_size=1) self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//2, kernel_size=1) self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.softmax = Softmax(dim=3) self.INF = INF self.gamma1 = nn.Parameter(torch.zeros(1)) # 虽然初始化为0了,但是它是一个可以学习的参数,当插入在模型中时,最开始可以保证从 self.gamma2 = nn.Parameter(torch.zeros(1)) # self.gamma2 = torch.zeros(1).cuda().requires_grad_()

    # ImageNet上学来的特征,然后再慢慢学习,会得到一个值,这可以使得整个训练过程更加的平滑


def forward(self, x, y):

    m_batchsize, _, height, width = x.size()  # B x 2C x H x W ,m_batchsize = 2, _ = 256, height = 47, width = 47
    proj_query_X = self.query_conv(x) # 降维,我改成了128,即降维一半, B,C,H,W
    proj_query_X_H = proj_query_X.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1) # BW,H,C
    proj_query_X_W = proj_query_X.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1) # BH,W,C
    proj_key_X = self.key_conv(x) # 降维  B,C,H,W
    proj_key_X_H = proj_key_X.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) # 12,8,5, BW,C,H
    proj_key_X_W = proj_key_X.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) # 10,8,6, BH,C,W
    proj_value_X = self.value_conv(x)  # 2,64,5,6 就是没有降维而已
    proj_value_X_H = proj_value_X.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) # 12,64,5 BW,2C,H
    proj_value_X_W = proj_value_X.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) # 10,64,6 BH,2C,W


    proj_query_Y = self.query_conv(y) # 降维 B,C,W,H
    proj_query_Y_H = proj_query_Y.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1) # BW,H,C
    proj_query_Y_W = proj_query_Y.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1) # BH,W,C
    proj_key_Y = self.key_conv(y) # 降维  B,C,W,H
    proj_key_Y_H = proj_key_Y.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) # 12,8,5, BW,C,H
    proj_key_Y_W = proj_key_Y.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) # 10,8,6, BH,C,W
    proj_value_Y = self.value_conv(y)  # 2,64,5,6 就是没有降维而已
    proj_value_Y_H = proj_value_Y.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) # 12,64,5 BW,2C,H
    proj_value_Y_W = proj_value_Y.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) # 10,64,6 BH,2C,W

    A = torch.bmm(proj_query_X_H, proj_key_Y_H)
    B = self.INF(m_batchsize, height, width)
    C = A+B
    # BW,H,H的注意力图中每一列包含了查询帧中的每一个H信息,BH,W,W同理
    energy_X_H = (torch.bmm(proj_query_X_H, proj_key_Y_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3) # B,H,W,H
    energy_X_W = torch.bmm(proj_query_X_W, proj_key_Y_W).view(m_batchsize,height,width,width)  # B,H,W,W
    concateX = self.softmax(torch.cat([energy_X_H, energy_X_W], 3))  # B,H,W,H+W

    att_X_H = concateX[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)  # BW,H,H
    att_X_W = concateX[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)  # BH,W,W

    # 与X一样
    energy_Y_H = (torch.bmm(proj_query_Y_H, proj_key_X_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
    energy_Y_W = torch.bmm(proj_query_Y_W, proj_key_X_W).view(m_batchsize,height,width,width)  # 2,5,6,6
    concateY = self.softmax(torch.cat([energy_Y_H, energy_Y_W], 3))  # 2,5,6,11

    att_Y_H = concateY[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)  # 12,5,5
    att_Y_W = concateY[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)  # 10,6,6
    # 因为这边permute()相当于做了个转置,所以应当是每一行,包含了查询帧中的每一个H信息
    out_Y_H = torch.bmm(proj_value_Y_H, att_X_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)  # 2,64,5,6
    out_Y_W = torch.bmm(proj_value_Y_W, att_X_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)  # 2,64,5,6

    out_X_H = torch.bmm(proj_value_X_H, att_Y_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)  # 2,64,5,6
    out_X_W = torch.bmm(proj_value_X_W, att_Y_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)  # 2,64,5,6


    return (self.gamma1 * (out_X_H + out_X_W) + x), (self.gamma2 * (out_Y_H + out_Y_W) + y)

`

另外这部分的初始化,卷积层为kaiming初始化,偏置0。BN层权重设为1,偏置0.

Lanezzz avatar Apr 02 '21 03:04 Lanezzz

@speedinghzl

Lanezzz avatar Apr 02 '21 03:04 Lanezzz