erfnet icon indicating copy to clipboard operation
erfnet copied to clipboard

pretrainedEnc = next(pretrainedEnc.children()).features.encoder

Open MahnazMesgar opened this issue 4 years ago • 0 comments

hello what does this line do? pretrainedEnc = next(pretrainedEnc.children()).features.encoder in main.py in train folder

`def main(args): savedir = f'../save/{args.savedir}'

if not os.path.exists(savedir):
    os.makedirs(savedir)

with open(savedir + '/opts.txt', "w") as myfile:
    myfile.write(str(args))

#Load Model
assert os.path.exists(args.model + ".py"), "Error: model definition not found"
model_file = importlib.import_module(args.model)
model = model_file.Net(NUM_CLASSES)
copyfile(args.model + ".py", savedir + '/' + args.model + ".py")

if args.cuda:
    model = torch.nn.DataParallel(model).cuda()

if args.state:
    #if args.state is provided then load this state for training
    #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!!
    """
    try:
        model.load_state_dict(torch.load(args.state))
    except AssertionError:
        model.load_state_dict(torch.load(args.state,
            map_location=lambda storage, loc: storage))
    #When model is saved as DataParallel it adds a model. to each key. To remove:
    #state_dict = {k.partition('model.')[2]: v for k,v in state_dict}
    #https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494
    """
    def load_my_state_dict(model, state_dict):  #custom function to load model when not all dict keys are there
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                 continue
            own_state[name].copy_(param)
        return model

    #print(torch.load(args.state))
    model = load_my_state_dict(model, torch.load(args.state))

"""
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        #m.weight.data.normal_(0.0, 0.02)
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif classname.find('BatchNorm') != -1:
        #m.weight.data.normal_(1.0, 0.02)
        m.weight.data.fill_(1)
        m.bias.data.fill_(0)

#TO ACCESS MODEL IN DataParallel: next(model.children())
#next(model.children()).decoder.apply(weights_init)
#Reinitialize weights for decoder

next(model.children()).decoder.layers.apply(weights_init)
next(model.children()).decoder.output_conv.apply(weights_init)

#print(model.state_dict())
f = open('weights5.txt', 'w')
f.write(str(model.state_dict()))
f.close()
"""

#train(args, model)
if (not args.decoder):
    print("========== ENCODER TRAINING ===========")
    model = train(args, model, True) #Train encoder
#CAREFUL: for some reason, after training encoder alone, the decoder gets weights=0. 
#We must reinit decoder weights or reload network passing only encoder in order to train decoder
print("========== DECODER TRAINING ===========")
if (not args.state):
    if args.pretrainedEncoder:
        print("Loading encoder pretrained in imagenet")
        from erfnet_imagenet import ERFNet as ERFNet_imagenet
        pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000))
        pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict'])
        pretrainedEnc = next(pretrainedEnc.children()).features.encoder
        if (not args.cuda):
            pretrainedEnc = pretrainedEnc.cpu()     #because loaded encoder is probably saved in cuda
    else:
        pretrainedEnc = next(model.children()).encoder
    model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc)  #Add decoder to encoder
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()
    #When loading encoder reinitialize weights for decoder because they are set to 0 when training dec
model = train(args, model, False)   #Train decoder
print("========== TRAINING FINISHED ===========")`

MahnazMesgar avatar Mar 03 '21 13:03 MahnazMesgar