PIDNet
About model loading
Hi, when I load the cityscapes_M_pretrained model for testing, the logger reports that I loaded 0 parameters, so I traced it to the code below:

```python
if imgnet_pretrained:
    pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']
    model_dict = model.state_dict()
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if (k in model_dict and v.shape == model_dict[k].shape)}
    model_dict.update(pretrained_state)
    msg = 'Loaded {} parameters!'.format(len(pretrained_state))
    logging.info('Attention!!!')
    logging.info(msg)
    logging.info('Over!!!')
    model.load_state_dict(model_dict, strict=False)
else:
    pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                       if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
    msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
    logging.info('Attention!!!')
    logging.info(msg)
    logging.info('Over!!!')
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)
```

When imgnet_pretrained is false, why are the keys of pretrained_dict sliced starting at index 6 (k[6:])? Is that the reason I load 0 params? Looking forward to your reply, thanks.
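As a side note, a quick way to see why 0 parameters match is to print a few key names from both sides and compare them; this is only a diagnostic sketch, and the checkpoint path below is an assumption to be adjusted to your setup:

```python
# Hedged sketch for diagnosing "Loaded 0 parameters!":
# compare the checkpoint's key names with the model's key names.
import torch

ckpt = torch.load('pretrained_models/cityscapes/PIDNet_M_Cityscapes_test.pt',  # assumed path
                  map_location='cpu')
state = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

print(list(state.keys())[:5])                 # e.g. keys with a wrapper prefix such as 'model.conv1.0.weight'
# print(list(model.state_dict().keys())[:5])  # e.g. 'conv1.0.weight' (uncomment once the model is built)
# If every checkpoint key carries a prefix that the model keys lack (or vice versa),
# the filtering dict comprehension matches nothing and 0 parameters get loaded.
```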
Hi,
PyTorch loads weights by parameter name, and the names coming out of DataParallel may be different. You can check: there should be a 'model.' prefix in front (which is why k[6:] strips the first six characters).
Thanks for your interest.
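To illustrate the reply above, here is a minimal sketch (not the repository's own loading code) of stripping a wrapper prefix before matching keys; the prefix names are assumptions based on the reply ('model.') and on nn.DataParallel's usual 'module.' prefix:

```python
# Hedged sketch: normalize checkpoint keys saved from a wrapped model so they
# match the bare model's state_dict keys.
import torch

def strip_prefix(state_dict, prefixes=('model.', 'module.')):
    """Return a copy of state_dict with any known wrapper prefix removed."""
    out = {}
    for k, v in state_dict.items():
        for p in prefixes:
            if k.startswith(p):
                k = k[len(p):]
                break
        out[k] = v
    return out

# Usage (illustrative): load, strip, then match against the bare model's keys.
# pretrained = strip_prefix(torch.load(path, map_location='cpu')['state_dict'])
# model.load_state_dict(pretrained, strict=False)
```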
Hi, I'm running into this problem as well; could you tell me what is going wrong? The code in question:

```python
if imgnet_pretrained:
    pretrained_state = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')['state_dict']
    model_dict = model.state_dict()
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if (k in model_dict and v.shape == model_dict[k].shape)}
    model_dict.update(pretrained_state)
    msg = 'Loaded {} parameters!'.format(len(pretrained_state))
    logging.info('Attention!!!')
    logging.info(msg)
    logging.info('Over!!!')
    model.load_state_dict(model_dict, strict=False)
else:
    pretrained_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                       if (k[6:] in model_dict and v.shape == model_dict[k[6:]].shape)}
    msg = 'Loaded {} parameters!'.format(len(pretrained_dict))
    logging.info('Attention!!!')
    logging.info(msg)
    logging.info('Over!!!')
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)
return model
```
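One more debugging suggestion (my own, not the repository's code): because the snippet loads the already-complete model_dict, load_state_dict never reports a mismatch. Loading the filtered pretrained_dict directly with strict=False returns the unmatched keys, which makes the "Loaded 0 parameters!" case easier to inspect:

```python
# Hedged sketch: load the filtered dict itself and inspect what failed to match.
result = model.load_state_dict(pretrained_dict, strict=False)
print('missing keys   :', result.missing_keys[:5])     # parameters the model expected but did not receive
print('unexpected keys:', result.unexpected_keys[:5])  # checkpoint entries the model could not place
```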