PIDNet

About model loading

Open finsary opened this issue 2 years ago • 2 comments

Hello, when I load the cityscapes_M_pretrained pretrained model for testing, the logger reports that 0 parameters were loaded, so I tracked it down to the following code:

    if imgnet_pretrained:
        pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']
        model_dict = model.state_dict()
        pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_state)
        msg = 'Loaded {} parameters!'.format(len(pretrained_state))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model.load_state_dict(model_dict, strict = False)
    else:
        pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
        msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict, strict = False)

When imgnet_pretrained is False, why are the keys of pretrained_dict sliced from index 6 (k[6:])? Is that why 0 parameters get loaded? Looking forward to your answer, thanks.

finsary · Nov 08 '22
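To see what k[6:] does, note that it drops exactly the first six characters of every key, which lines up with a six-character prefix such as "model." (m, o, d, e, l, .). The sketch below is a minimal pure-Python illustration, assuming plain dicts stand in for state dicts and tuples stand in for tensor shapes; all key names are hypothetical. It mimics the else-branch filter and shows that it finds 0 matches whenever the checkpoint keys do not carry exactly that kind of prefix, for example when they have no prefix at all or when nn.DataParallel has added its seven-character "module." prefix.

```python
# Plain dicts stand in for state dicts; tuples stand in for tensor shapes.
# All key names below are hypothetical, chosen only to illustrate the slicing.

PREFIX = "model."
assert len(PREFIX) == 6  # six characters, hence k[6:] in the loader


def filter_like_loader(pretrained, model_dict):
    """Mimic the else-branch: drop the first 6 characters of each key and
    keep entries whose stripped name exists in the model with the same shape."""
    return {
        k[6:]: v
        for k, v in pretrained.items()
        if k[6:] in model_dict and v == model_dict[k[6:]]
    }


model_dict = {"conv1.weight": (32, 3, 3, 3), "conv1.bias": (32,)}

# Keys saved with a "model." prefix: slicing recovers the model's own names.
with_prefix = {"model.conv1.weight": (32, 3, 3, 3), "model.conv1.bias": (32,)}

# Keys saved without any prefix: slicing cuts into the real name
# ("conv1.weight"[6:] == "weight"), so nothing matches.
without_prefix = {"conv1.weight": (32, 3, 3, 3), "conv1.bias": (32,)}

# Keys with an extra "module." prefix (7 characters, added by nn.DataParallel):
# slicing leaves a leading "." ("module.model.conv1.bias"[6:] == ".model.conv1.bias").
dp_prefix = {
    "module.model.conv1.weight": (32, 3, 3, 3),
    "module.model.conv1.bias": (32,),
}

for name, ckpt in [("with_prefix", with_prefix),
                   ("without_prefix", without_prefix),
                   ("dp_prefix", dp_prefix)]:
    print(name, "->", len(filter_like_loader(ckpt, model_dict)), "parameters")
# with_prefix -> 2 parameters
# without_prefix -> 0 parameters
# dp_prefix -> 0 parameters
```

So a "Loaded 0 parameters!" message from this branch usually means the checkpoint's key names do not start with the six characters the slice expects; printing a few keys of both the checkpoint and model.state_dict() is the quickest way to confirm.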

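Hello,

PyTorch loads weights by name, and the parameter names that come out of DataParallel may differ from the model's own. Please check the keys in the checkpoint; they should start with a "model." prefix.

Thanks for your interest.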

XuJiacong · Nov 08 '22
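Since the loader matches weights purely by name, one option is to strip wrapper prefixes by name instead of with a fixed k[6:] slice. The sketch below uses two hypothetical helpers (strip_prefixes and match_keys, not part of PIDNet) and assumes the only wrapper prefixes are "module." and "model."; plain dicts and shape tuples again stand in for real state dicts and tensors. It also returns the names it skipped, which makes a 0-parameter load easier to diagnose.

```python
# Hypothetical helpers, not part of PIDNet. Plain dicts stand in for state
# dicts; tuples stand in for tensor shapes.

WRAPPER_PREFIXES = ("module.", "model.")  # assumed wrapper prefixes


def strip_prefixes(state, prefixes=WRAPPER_PREFIXES):
    """Remove any leading wrapper prefixes (repeatedly) from every key."""
    out = {}
    for key, value in state.items():
        stripped = True
        while stripped:
            stripped = False
            for p in prefixes:
                if key.startswith(p):
                    key = key[len(p):]
                    stripped = True
        out[key] = value
    return out


def match_keys(pretrained, model_shapes):
    """Return (entries whose name and shape match the model, sorted skipped names)."""
    cleaned = strip_prefixes(pretrained)
    loaded = {k: v for k, v in cleaned.items() if model_shapes.get(k) == v}
    skipped = sorted(set(cleaned) - set(loaded))
    return loaded, skipped


model_shapes = {"conv1.weight": (32, 3, 3, 3), "conv1.bias": (32,)}
ckpt = {
    "module.model.conv1.weight": (32, 3, 3, 3),  # DataParallel + wrapper prefix
    "model.conv1.bias": (32,),                   # wrapper prefix only
    "criterion.weight": (19,),                   # not a model parameter
}
loaded, skipped = match_keys(ckpt, model_shapes)
print(len(loaded), "loaded; skipped:", skipped)
# 2 loaded; skipped: ['criterion.weight']
```

With real tensors the shape test would compare v.shape against model_dict[k].shape before model_dict.update(loaded) and model.load_state_dict(model_dict, strict = False), as in the loader quoted in the question. One caveat of this design: a parameter whose genuine name begins with "model." would also lose that segment, so the prefix list should be adjusted to match the actual checkpoint.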


Hello, could you tell me what is causing the error shown in my screenshot? This is the loading code I am running:

    if imgnet_pretrained:
        pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']
        model_dict = model.state_dict()
        pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_state)
        msg = 'Loaded {} parameters!'.format(len(pretrained_state))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model.load_state_dict(model_dict, strict = False)
    else:
        pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
        msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
        logging.info('Attention!!!')
        logging.info(msg)
        logging.info('Over!!!')
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict, strict = False)

    return model

[Screenshot: Snipaste_2024-03-05_18-02-37]

QinPeng1 · Mar 05 '24