External-Attention-pytorch icon indicating copy to clipboard operation
External-Attention-pytorch copied to clipboard

Update PSA.py

Open nlper01 opened this issue 2 years ago • 1 comments

import numpy as np import torch from torch import nn from torch.nn import init

class PSA(nn.Module):

def __init__(self, channel=512, reduction=4, S=4):
    super().__init__()
    self.S = S

    self.convs = nn.ModuleList(
        [nn.Conv2d(channel // S, channel // S, kernel_size=2 * (i + 1) + 1, padding=(i + 1)) for i in range(S)])
    # self.convs=[]
    # for i in range(S):
    #     self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1))

    self.se_blocks = nn.ModuleList(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel // S, channel // (S * reduction), kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // (S * reduction), channel // S, kernel_size=1, bias=False),
            nn.Sigmoid()
        ) for i in range(S)
    )
    # self.se_blocks=[]
    # for i in range(S):
    #     self.se_blocks.append(nn.Sequential(
    #         nn.AdaptiveAvgPool2d(1),
    #         nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False),
    #         nn.ReLU(inplace=True),
    #         nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False),
    #         nn.Sigmoid()
    #     ))

    self.softmax = nn.Softmax(dim=1)


def init_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=0.001)
            if m.bias is not None:
                init.constant_(m.bias, 0)

def forward(self, x):
    b, c, h, w = x.size()

    #Step1:SPC module
    SPC_out=x.view(b,self.S,c//self.S,h,w) #bs,s,ci,h,w
    for idx,conv in enumerate(self.convs):
        SPC_out[:,idx,:,:,:]=conv(SPC_out[:,idx,:,:,:])

    #Step2:SE weight
    se_out=[]
    for idx,se in enumerate(self.se_blocks):
        se_out.append(se(SPC_out[:,idx,:,:,:]))
    SE_out=torch.stack(se_out,dim=1)
    SE_out=SE_out.expand_as(SPC_out)

    #Step3:Softmax
    softmax_out=self.softmax(SE_out)

    #Step4:SPA
    PSA_out=SPC_out*softmax_out
    PSA_out=PSA_out.view(b,-1,h,w)

    return PSA_out

if name == 'main': device = torch.device('cuda') input = torch.randn(8, 512, 7, 7).to(device) psa = PSA(channel=512, reduction=8).to(device) output = psa(input) a = output.view(-1).sum() a.backward() print(output.shape)

解决了PSA.py模块用到自己的网络中时会出现 RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same的问题 不过还存在就地操作问题,梯度计算出错,希望大佬能够帮忙解决一下 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 128, 7, 7]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

nlper01 avatar Sep 13 '22 07:09 nlper01

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 64, 20, 20]], which is output 0 of ReluBackward1, is at version 5; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).只改这个nn.ModuleList,确实出现一样的问题

zeng-cy avatar Dec 02 '22 01:12 zeng-cy