# ResidualMaskingNetwork
# Issue: ResMasking forward function raises TypeError
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:  # torchvision.models.utils is absent in newer torchvision releases
    from torch.hub import load_state_dict_from_url
# Download locations of the ImageNet-pretrained torchvision ResNet checkpoints,
# keyed by architecture name. The entries were empty strings, which makes
# load_state_dict_from_url() fail whenever no local weight_path is supplied.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
}
class ResMasking(ResNet):
    """ResNet (BasicBlock, layers [2, 2, 2, 2]) with a masking branch per stage.

    After each of ``layer1`` .. ``layer4`` a small convolutional branch
    computes a mask ``m`` with the same shape as the stage output ``x`` and
    the features are rescaled as ``x * (1 + m)``.

    Parameters
    ----------
    weight_path : str
        Path to a local checkpoint holding a state dict. When empty, the
        ``"resnet18"`` entry of ``model_urls`` is downloaded instead.
    """

    def __init__(self, weight_path=""):
        super(ResMasking, self).__init__(block=BasicBlock, layers=[2, 2, 2, 2])

        if weight_path:
            # map_location="cpu" lets a checkpoint saved on a GPU be read on a
            # machine without one; load_state_dict copies into the existing
            # parameters either way.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(
                model_urls["resnet18"], progress=True
            )
        # strict=False: missing / unexpected keys are tolerated.
        self.load_state_dict(state_dict, strict=False)

        # The classifier and the mask branches are created after loading, so
        # they always start from fresh initialisation.
        self.fc = nn.Linear(512, 7)

        self.mask1 = self._masking(64, 64, depth=4)
        self.mask2 = self._masking(128, 128, depth=3)
        self.mask3 = self._masking(256, 256, depth=2)
        self.mask4 = self._masking(512, 512, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Build a mask branch made of ``depth`` 3x3 convolutions.

        NOTE(review): the original body was the syntactically invalid
        ``return nn.Sequential(conv, conv) for _ in range(depth - 1)``. It is
        reconstructed here as one ``in_channels -> out_channels`` convolution
        followed by ``depth - 1`` ``out_channels -> out_channels``
        convolutions -- confirm this matches the intended design. No
        activation is applied, so the mask values are unbounded.
        """
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        layers.extend(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in range(depth - 1)
        )
        # kernel_size=3 with padding=1 keeps the spatial size, so the mask can
        # be multiplied element-wise with the stage output.
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the backbone, applying ``x * (1 + mask(x))`` after each stage.

        Returns the output of ``self.fc`` on the pooled, flattened features.
        """
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        stages = (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        )
        for layer, mask in stages:
            x = layer(x)
            m = mask(x)
            x = x * (1 + m)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
class ResMasking50(ResNet):
    """ResNet (Bottleneck, layers [3, 4, 6, 3]) with a masking branch per stage.

    After each of ``layer1`` .. ``layer4`` a small convolutional branch
    computes a mask ``m`` with the same shape as the stage output ``x`` and
    the features are rescaled as ``x * (1 + m)``.

    Parameters
    ----------
    weight_path : str
        Path to a local checkpoint holding a state dict. When empty, the
        ``"resnet50"`` entry of ``model_urls`` is downloaded instead.
    """

    def __init__(self, weight_path=""):
        super(ResMasking50, self).__init__(block=Bottleneck, layers=[3, 4, 6, 3])

        if weight_path:
            # map_location="cpu" lets a checkpoint saved on a GPU be read on a
            # machine without one; load_state_dict copies into the existing
            # parameters either way.
            state_dict = torch.load(weight_path, map_location="cpu")
        else:
            state_dict = load_state_dict_from_url(
                model_urls["resnet50"], progress=True
            )
        # strict=False: missing / unexpected keys are tolerated.
        self.load_state_dict(state_dict, strict=False)

        # The classifier and the mask branches are created after loading, so
        # they always start from fresh initialisation.
        self.fc = nn.Linear(2048, 7)

        self.mask1 = self._masking(256, 256, depth=4)
        self.mask2 = self._masking(512, 512, depth=3)
        self.mask3 = self._masking(1024, 1024, depth=2)
        self.mask4 = self._masking(2048, 2048, depth=1)

    def _masking(self, in_channels, out_channels, depth):
        """Build a mask branch made of ``depth`` 3x3 convolutions.

        NOTE(review): the original body was the syntactically invalid
        ``return nn.Sequential(conv, conv) for _ in range(depth - 1)``. It is
        reconstructed here as one ``in_channels -> out_channels`` convolution
        followed by ``depth - 1`` ``out_channels -> out_channels``
        convolutions -- confirm this matches the intended design. No
        activation is applied, so the mask values are unbounded.
        """
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        layers.extend(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in range(depth - 1)
        )
        # kernel_size=3 with padding=1 keeps the spatial size, so the mask can
        # be multiplied element-wise with the stage output.
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the backbone, applying ``x * (1 + mask(x))`` after each stage.

        Returns the output of ``self.fc`` on the pooled, flattened features.
        """
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        stages = (
            (self.layer1, self.mask1),
            (self.layer2, self.mask2),
            (self.layer3, self.mask3),
            (self.layer4, self.mask4),
        )
        for layer, mask in stages:
            x = layer(x)
            m = mask(x)
            x = x * (1 + m)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def resmasking(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking`` model with a plain linear classifier.

    Parameters
    ----------
    in_channels : int
        Accepted for signature parity with the other factories; not used.
    num_classes : int
        Number of classifier outputs. Previously ignored; the head is now
        replaced when this differs from the size ``ResMasking`` creates.
    weight_path : str
        Forwarded to ``ResMasking``.
    """
    model = ResMasking(weight_path)
    # Only swap the head when a different size is requested, so the default
    # call returns exactly what it returned before.
    if num_classes != model.fc.out_features:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
def resmasking50_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking50`` whose head is Dropout(0.4) -> Linear(2048, num_classes).

    ``in_channels`` is accepted but not used; ``weight_path`` is forwarded to
    ``ResMasking50``.
    """
    net = ResMasking50(weight_path)
    dropout = nn.Dropout(0.4)
    classifier = nn.Linear(2048, num_classes)
    net.fc = nn.Sequential(dropout, classifier)
    return net
def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking`` whose head is Dropout(0.4) -> Linear(512, num_classes).

    ``in_channels`` is accepted but not used; ``weight_path`` is forwarded to
    ``ResMasking``.
    """
    net = ResMasking(weight_path)
    dropout = nn.Dropout(0.4)
    classifier = nn.Linear(512, num_classes)
    net.fc = nn.Sequential(dropout, classifier)
    return net
def resmasking_dropout2(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking`` with a two-layer head.

    The head is Linear(512, 128) -> ReLU -> Dropout(0.5) ->
    Linear(128, num_classes). ``in_channels`` is accepted but not used;
    ``weight_path`` is forwarded to ``ResMasking``.
    """
    net = ResMasking(weight_path)
    head_layers = [
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(128, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net
def resmasking_dropout3(in_channels=3, num_classes=7, weight_path=""):
    """Build a ``ResMasking`` with a three-layer head.

    The head is Linear(512, 512) -> ReLU -> Dropout -> Linear(512, 128) ->
    ReLU -> Dropout -> Linear(128, num_classes), with in-place ReLUs and
    default dropout probability. ``in_channels`` is accepted but not used;
    ``weight_path`` is forwarded to ``ResMasking``.
    """
    net = ResMasking(weight_path)
    head_layers = [
        nn.Linear(512, 512),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(512, 128),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(128, num_classes),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net
# Reported error:
#   TypeError: ResMasking.forward() got an unexpected keyword argument 'in_channels'
def main(config_path):
    """Load the configuration, build the model and datasets, and start training.

    Parameters
    ----------
    config_path : str
        path to config file
    """
    # load configs (closing the file instead of leaking the handle)
    with open(config_path, encoding="utf-8") as config_file:
        configs = json.load(config_file)
    configs["cwd"] = os.getcwd()

    # load model and data_loader
    model = get_model(configs)
    train_set, val_set, test_set = get_dataset(configs)

    # init trainer and make a training
    # from trainers.fer2013_trainer import FER2013Trainer
    # from trainers.centerloss_trainer import FER2013Trainer
    trainer = FER2013Trainer(model, train_set, val_set, test_set, configs)

    # NOTE(review): training is only launched in distributed mode; when
    # configs["distributed"] != 1 nothing else happens here -- confirm whether
    # a single-process training call is missing.
    if configs["distributed"] == 1:
        # one worker process per visible CUDA device
        ngpus = torch.cuda.device_count()
        mp.spawn(trainer.train, nprocs=ngpus, args=())
def get_model(configs):
    """Return the model factory selected by ``configs["arch"]``.

    The factory itself is returned, not an instantiated network. The reported
    ``forward() got an unexpected keyword argument 'in_channels'`` error
    arises when an already-built ``nn.Module`` is later called with
    ``in_channels=...``: the keyword is forwarded to ``forward``. Returning
    the callable lets the consumer construct the network with those keywords.

    NOTE(review): the arguments of the original ``resmasking_dropout3(`` call
    are not visible in this file; confirm the trainer instantiates the model
    as ``model(in_channels=..., num_classes=...)``.

    Raises
    ------
    ValueError
        If ``configs["arch"]`` names an unsupported architecture.
    """
    # Only "resmasking_dropout3" is supported here.
    if configs["arch"] == "resmasking_dropout3":
        return resmasking_dropout3

    # Handle case where 'arch' does not match
    raise ValueError(f"Model architecture {configs['arch']} is not supported.")
def get_dataset(configs):
    """Create the raw train, validation and test datasets.

    Each split is built by calling ``fer2013`` with the split name and
    ``configs``; the test split additionally receives ``tta=True`` and
    ``tta_size=10``. Returns the tuple ``(train_set, val_set, test_set)``.
    """
    # todo: add transform
    splits = (
        fer2013("train", configs),
        fer2013("val", configs),
        fer2013("test", configs, tta=True, tta_size=10),
    )
    return splits
# Script entry point. The guard must compare the module's __name__ with
# "__main__"; the bare names `name` / "main" raise NameError at import time.
if __name__ == "__main__":
    main("/content/drive/MyDrive/Resnet/fer2013_config.json")