pytorch-groupnormalization
pytorch-groupnormalization copied to clipboard
Questions about expand the resnet.py to 3d and hope you can provide an example of 3D group normalization.
Hello chengyangfu! Thank you for providing such a great work! When I expanded the resnet.py to 3d and divided the number of groups into 32 groups , I encountered a problem: RuntimeError: Expected 2 to 5 dimensions, but got 6-dimensional tensor for argument #1 'input' (while checking arguments for cudnn_batch_norm). Could you please tell me why this problem occurs and how to solve it? Could you please provide me with an example using 3d group normalization? Thank you very much! Looking forward to your reply! Best regards!
PS: I will attach the error trace and the code I modified. Thanks again!
this is the error trace:
Traceback (most recent call last):
File "D:/LiuJiaqi/resnet_gn/models/resnet+gn_3d_changegroup.py", line 260, in
Process finished with exit code 1
this is the code I modified: `import torch import torch.nn as nn import math import torch.utils.model_zoo as model_zoo from group_norm import GroupNorm3d # 3d
all = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', }
def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def norm3d(planes, num_channels_per_group=32): print("num_channels_per_group:{}".format(num_channels_per_group)) if num_channels_per_group > 0: return GroupNorm3d(planes, num_channels_per_group, affine=True, track_running_stats=False) else: return nn.BatchNorm3d(planes)
class BasicBlock(nn.Module): expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
group_norm=32):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm3d(planes, group_norm)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm3d(planes, group_norm)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module): expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
group_norm=32):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = norm3d(planes, group_norm)
self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = norm3d(planes, group_norm)
self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm3d(planes * 4, group_norm)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, group_norm=32):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm3d(64, group_norm)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0],
group_norm=group_norm)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
group_norm=group_norm)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
group_norm=group_norm)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
group_norm=group_norm)
self.avgpool = nn.AvgPool3d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, GroupNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
for m in self.modules():
if isinstance(m, Bottleneck):
m.bn3.weight.data.fill_(0)
if isinstance(m, BasicBlock):
m.bn2.weight.data.fill_(0)
def _make_layer(self, block, planes, blocks, stride=1, group_norm=0):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv3d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
norm3d(planes * block.expansion, group_norm),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample,
group_norm))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, group_norm=group_norm))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
if name == 'main': import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "1,0"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# [batch,channel,H,W]
#img = torch.rand(2, 3, 224, 224).cuda()
img = torch.rand(2, 3, 224, 224, 224).cuda()# after changing to 3d, it is modified here.
# (2, 3, 224, 224, 224) (B,C,D,H,W)
net = resnet18(num_classes=2, group_norm=32).cuda().train()
result = net(img)
print(result.shape)
`