erfnet
erfnet copied to clipboard
pretrainedEnc = next(pretrainedEnc.children()).features.encoder
Hello, what does this line do? `pretrainedEnc = next(pretrainedEnc.children()).features.encoder` (in `main.py` in the `train` folder)
def _unwrap_parallel(net):
    """Return the network wrapped by ``torch.nn.DataParallel``, else *net* itself.

    ``DataParallel`` registers the wrapped network as its only child module
    (the ``.module`` attribute), so for a wrapped model this is the same object
    that ``next(net.children())`` yields. For an unwrapped model,
    ``next(net.children())`` would instead yield the model's *first submodule*,
    which is why the type check is needed.
    """
    if isinstance(net, torch.nn.DataParallel):
        return net.module
    return net


def _load_matching_state(model, state_dict):
    """Copy every tensor of *state_dict* whose key also exists in *model*.

    Used to initialise weights from a state file that may not contain all of
    the model's keys. Keys in *state_dict* that the model does not have are
    skipped, and a warning reports how many were skipped. Without it, a
    wholesale name mismatch (e.g. ``module.``-prefixed keys saved from a
    ``DataParallel`` model vs. an unwrapped model) would silently load nothing.
    Model keys missing from *state_dict* keep their current values.

    Args:
        model: the ``torch.nn.Module`` to update in place.
        state_dict: mapping of parameter/buffer name to tensor.

    Returns:
        *model*, for convenience.
    """
    own_state = model.state_dict()
    matched = [name for name in state_dict if name in own_state]
    for name in matched:
        # copy_ works across devices, so a CPU-loaded tensor can fill a
        # CUDA parameter (and vice versa).
        own_state[name].copy_(state_dict[name])
    skipped = len(state_dict) - len(matched)
    if skipped:
        print(f"Warning: {skipped} of {len(state_dict)} tensors in the state "
              "file have no matching key in the model and were skipped")
    return model


def main(args):
    """Train a segmentation model: encoder stage (unless skipped), then decoder.

    Saves ``args`` and a copy of the model definition under
    ``../save/<args.savedir>``, optionally initialises weights from
    ``args.state``, and then:

    1. unless ``args.decoder`` is set, trains the encoder alone
       (``train(args, model, True)``);
    2. trains the full network (``train(args, model, False)``). When
       ``args.state`` is not given, the network for this stage is rebuilt
       around an encoder taken either from an ImageNet-pretrained ERFNet
       (``args.pretrainedEncoder``) or from the model trained in step 1.

    Raises:
        FileNotFoundError: if ``<args.model>.py`` does not exist.
    """
    savedir = f'../save/{args.savedir}'
    os.makedirs(savedir, exist_ok=True)

    with open(os.path.join(savedir, 'opts.txt'), "w") as myfile:
        myfile.write(str(args))

    # Load model
    model_path = args.model + ".py"
    if not os.path.exists(model_path):
        # Explicit raise rather than assert: asserts are stripped under `python -O`.
        raise FileNotFoundError("Error: model definition not found")
    model_file = importlib.import_module(args.model)
    model = model_file.Net(NUM_CLASSES)
    copyfile(model_path, os.path.join(savedir, model_path))

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    if args.state:
        # Initialise weights from args.state. This only loads initial weights;
        # to resume an interrupted training run use the "--resume" option.
        # Loading onto the CPU lets a CUDA-saved file be used on a CPU-only
        # machine; copy_ then moves each tensor to wherever the model lives.
        state_dict = torch.load(args.state,
                                map_location=lambda storage, loc: storage)
        model = _load_matching_state(model, state_dict)

    if not args.decoder:
        print("========== ENCODER TRAINING ===========")
        model = train(args, model, True)  # Train encoder

    # CAREFUL (original author's note): after training the encoder alone the
    # decoder was seen to end up with weights = 0, so unless a state file was
    # given the network is rebuilt below around the encoder only.
    print("========== DECODER TRAINING ===========")
    if not args.state:
        if args.pretrainedEncoder:
            print("Loading encoder pretrained in imagenet")
            from erfnet_imagenet import ERFNet as ERFNet_imagenet
            # Wrapped in DataParallel so parameter names carry the "module."
            # prefix -- presumably how the checkpoint was saved; verify.
            pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
            pretrainedEnc.load_state_dict(
                torch.load(args.pretrainedEncoder,
                           map_location=lambda storage, loc: storage)['state_dict'])
            # Unwrap the DataParallel container (originally written as
            # next(pretrainedEnc.children()): the container's only child is the
            # wrapped ERFNet_imagenet), then take its `.features.encoder`
            # submodule, i.e. the ImageNet-trained encoder that is plugged into
            # the segmentation Net below. (Attribute layout comes from
            # erfnet_imagenet.py -- assumed, verify there.)
            pretrainedEnc = _unwrap_parallel(pretrainedEnc).features.encoder
            if not args.cuda:
                # Safeguard: the network above was built on the CPU and
                # load_state_dict copies into its existing tensors, so this
                # should already be a no-op.
                pretrainedEnc = pretrainedEnc.cpu()
        else:
            # Reuse the encoder of the model trained above. The model is only
            # DataParallel-wrapped when args.cuda is set, so unwrap
            # conditionally: an unconditional next(model.children()) returns a
            # submodule of the bare Net on the CPU path.
            pretrainedEnc = _unwrap_parallel(model).encoder
        # Add a decoder to the chosen encoder by building a new Net around it
        # (presumably giving the decoder fresh initial weights).
        model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc)
        if args.cuda:
            model = torch.nn.DataParallel(model).cuda()

    model = train(args, model, False)  # Train decoder
    print("========== TRAINING FINISHED ===========")