External-Attention-pytorch
Update PSA.py
import numpy as np
import torch
from torch import nn
from torch.nn import init
class PSA(nn.Module):
    """Pyramid Split Attention (PSA) module.

    The input channels are split into ``S`` equal groups. Group ``i`` is
    convolved with a ``(2 * (i + 1) + 1)``-sized kernel whose padding keeps the
    spatial size (SPC step). A per-group SE branch produces channel weights,
    which are normalised with a softmax across the ``S`` groups. The
    re-weighted groups are then merged back along the channel axis, so the
    output has the same shape as the input.

    Args:
        channel: Number of input (and output) channels; must be divisible by ``S``.
        reduction: Channel reduction ratio inside each SE branch.
        S: Number of channel groups (pyramid scales).

    Raises:
        ValueError: If ``channel`` is not divisible by ``S`` or the SE hidden
            width ``channel // (S * reduction)`` would be zero.
    """

    def __init__(self, channel=512, reduction=4, S=4):
        super().__init__()
        if channel % S != 0:
            raise ValueError(f"channel ({channel}) must be divisible by S ({S})")
        group_channels = channel // S
        hidden_channels = channel // (S * reduction)
        if hidden_channels < 1:
            raise ValueError(
                f"channel // (S * reduction) must be >= 1, got "
                f"{channel} // ({S} * {reduction}) = {hidden_channels}"
            )

        self.S = S

        # One conv per group; kernel 3, 5, 7, ... with padding i + 1 so the
        # spatial size (h, w) is preserved for every group.
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    group_channels,
                    group_channels,
                    kernel_size=2 * (i + 1) + 1,
                    padding=i + 1,
                )
                for i in range(S)
            ]
        )

        # One squeeze-and-excitation branch per group: (b, ci, h, w) -> (b, ci, 1, 1).
        # The ReLU is deliberately not in-place so no autograd-saved tensor is mutated.
        self.se_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(1),
                    nn.Conv2d(group_channels, hidden_channels, kernel_size=1, bias=False),
                    nn.ReLU(),
                    nn.Conv2d(hidden_channels, group_channels, kernel_size=1, bias=False),
                    nn.Sigmoid(),
                )
                for _ in range(S)
            ]
        )

        # Softmax over the group axis (dim=1 of the (b, S, ci, 1, 1) SE tensor).
        self.softmax = nn.Softmax(dim=1)

    def init_weights(self):
        """Initialise conv, batch-norm and linear layers in place.

        Conv weights use Kaiming-normal (fan_out). BatchNorm weights are set to
        1 and biases to 0. Linear weights use N(0, 0.001). Any present biases
        are zeroed. Not called automatically by ``__init__``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply pyramid split attention.

        Args:
            x: Tensor of shape (b, c, h, w) with ``c == channel``.

        Returns:
            Tensor of shape (b, c, h, w).
        """
        b, c, h, w = x.size()

        # Step 1: SPC module. The grouped tensor may share storage with ``x``.
        # Per-group conv outputs are stacked into a NEW tensor rather than being
        # written back into that view. Writing back would modify ``x`` in place
        # and break backward for any upstream op that saved ``x``
        # (e.g. an in-place ReLU).
        groups = x.reshape(b, self.S, c // self.S, h, w)  # bs, s, ci, h, w
        spc_out = torch.stack(
            [conv(groups[:, idx]) for idx, conv in enumerate(self.convs)], dim=1
        )  # bs, s, ci, h, w

        # Step 2: SE weight for each group.
        se_out = torch.stack(
            [se(spc_out[:, idx]) for idx, se in enumerate(self.se_blocks)], dim=1
        )  # bs, s, ci, 1, 1

        # Step 3: softmax across the S groups.
        attention = self.softmax(se_out)

        # Step 4: re-weight each group (broadcast over h, w) and merge the
        # groups back into the channel axis.
        psa_out = spc_out * attention
        return psa_out.reshape(b, c, h, w)
if __name__ == '__main__':
    # Smoke test: one forward and backward pass on a random batch.
    # Use CUDA when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = torch.randn(8, 512, 7, 7).to(device)
    psa = PSA(channel=512, reduction=8).to(device)
    output = psa(inputs)
    a = output.view(-1).sum()
    a.backward()
    print(output.shape)
解决了将 PSA.py 模块用到自己的网络中时会出现的 RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 问题。不过还存在就地（in-place）操作问题，导致梯度计算出错，希望大佬能够帮忙解决一下：RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 128, 7, 7]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 64, 20, 20]], which is output 0 of ReluBackward1, is at version 5; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True). 只修改这个 nn.ModuleList 后，确实也出现了同样的问题。