CheXNet

model.load_state_dict(checkpoint['state_dict']) error with pytorch 0.4.0

Open alexandrecc opened this issue 6 years ago • 10 comments

I was running the code without any problem on PyTorch 0.3.0. I upgraded yesterday to PyTorch 0.4.0 and now can't load the checkpoint file. I am on Ubuntu with Python 3.6 in a conda env. I get this error:

RuntimeError                              Traceback (most recent call last)
in ()
    181 if __name__ == '__main__':
--> 182     main()

in main()
     39     print("=> loading checkpoint")
     40     checkpoint = torch.load(CKPT_PATH)
---> 41     model.load_state_dict(checkpoint['state_dict'])
     42     print("=> loaded checkpoint")
     43     else:

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                 self.__class__.__name__, "\n\t".join(error_msgs)))
    722
    723     def parameters(self):

RuntimeError: Error(s) in loading state_dict for DenseNet121: Missing key(s) in state_dict: "densenet121.features.conv0.weight", "densenet121.features.norm0.weight", "densenet121.features.norm0.bias", "densenet121.features.norm0.running_mean", "densenet121.features.norm0.running_var", "densenet121.features.denseblock1.denselayer1.norm1.weight", "densenet121.features.denseblock1.denselayer1.norm1.bias", "densenet121.features.denseblock1.denselayer1.norm1.running_mean", (entire network ...) "module.densenet121.features.denseblock4.denselayer16.conv.2.weight", "module.densenet121.features.norm5.weight", "module.densenet121.features.norm5.bias", "module.densenet121.features.norm5.running_mean", "module.densenet121.features.norm5.running_var", "module.densenet121.classifier.0.weight", "module.densenet121.classifier.0.bias".

It is likely related to this change in PyTorch 0.4.0 (https://pytorch.org/2018/04/22/0_4_0-migration-guide.html): "New edge-case constraints on names of submodules, parameters, and buffers in nn.Module: a name that is an empty string or contains '.' is no longer permitted in module.add_module(name, value), module.add_parameter(name, value) or module.add_buffer(name, value), because such names may cause lost data in the state_dict. If you are loading a checkpoint for modules containing such names, please update the module definition and patch the state_dict before loading it."
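A quick way to confirm the mismatch is to diff the checkpoint keys against the keys a freshly built model expects. A minimal sketch, assuming the released model.pth.tar and approximating the repo's DenseNet121 wrapper by prefixing a bare torchvision densenet121 with the wrapper's attribute name (both the path and the prefix are assumptions to adapt):

import torch
import torchvision

CKPT_PATH = 'model.pth.tar'  # assumed path to the released checkpoint

checkpoint = torch.load(CKPT_PATH, map_location='cpu')
ckpt_keys = set(checkpoint['state_dict'].keys())

# Approximate the repo's wrapper: its DenseNet121 module stores the backbone
# as self.densenet121, so every parameter key gains a 'densenet121.' prefix.
backbone = torchvision.models.densenet121(pretrained=False)
model_keys = {'densenet121.' + k for k in backbone.state_dict().keys()}

print('in checkpoint but not in model:', sorted(ckpt_keys - model_keys)[:5])
print('in model but not in checkpoint:', sorted(model_keys - ckpt_keys)[:5])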

alexandrecc avatar May 26 '18 00:05 alexandrecc

I have the same problem too. @alexandrecc, do you have any solution so far?

drfikrah avatar Sep 12 '18 16:09 drfikrah

You should be able to make something like this work.

import re
import torch

# Key-renaming logic adapted from the torchvision densenet source, for loading
# pre-0.4 densenet weights (old '.norm.1'-style names) into a 0.4+ model.
checkpoint = torch.load('./model.pth.tar')
state_dict = checkpoint['state_dict']
remove_data_parallel = False  # Set to True to strip the 'module.' prefix if your model is NOT wrapped in nn.DataParallel

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
for key in list(state_dict.keys()):
    match = pattern.match(key)
    new_key = match.group(1) + match.group(2) if match else key
    new_key = new_key[7:] if remove_data_parallel else new_key
    state_dict[new_key] = state_dict[key]
    # Delete the old key only if it was modified.
    if match or remove_data_parallel:
        del state_dict[key]
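With the keys remapped, loading should go through. A minimal usage sketch, assuming `model` is the repo's DenseNet121 wrapper and that remove_data_parallel was left False (i.e. the model is wrapped in DataParallel the same way the checkpoint was saved):

import torch.nn as nn

# Wrap the model the same way the checkpoint was saved so the 'module.' prefix
# in the remapped keys still matches; skip the wrapping if remove_data_parallel is True.
model = nn.DataParallel(model)
model.load_state_dict(state_dict)
model.eval()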

ghost avatar Sep 14 '18 17:09 ghost

Thanks JasperJenkins... This worked but I received another error in the form:

Traceback (most recent call last):
  File "C:/Users/Nasir Isa/Documents/1Research/algortihm/CheXNet-master/CheXNet-master/m1.py", line 142, in <module>
    main()
  File "C:/Users/Nasir Isa/Documents/1Research/algortihm/CheXNet-master/CheXNet-master/m1.py", line 83, in main
    for i, (inp, target) in enumerate(test_loader):
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 501, in __iter__
    return _DataLoaderIter(self)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 289, in __init__
    w.start()
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\process.py", line 105, in start
    self._popen = self._Popen(self)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'main..'

Please help me out!

pharouknucleus avatar Sep 16 '18 01:09 pharouknucleus

Same Here

sukumargaonkar avatar Mar 09 '19 21:03 sukumargaonkar

+1. Seeing this error as well. Has torch implemented some way to ensure backwards compatibility when loading older checkpoints? I can't seem to find anything, and I would rather not change the keys themselves since that seems quite error-prone.

robhyb19 avatar May 05 '19 00:05 robhyb19

Thanks JasperJenkins... This worked but I received another error in the form: Traceback (most recent call last): [...] AttributeError: Can't pickle local object 'main..' Please help me out!

Set num_workers=0 and try again to find the real issue.
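For context, this is a common Windows pitfall: DataLoader workers are started with spawn, so everything handed to the DataLoader (dataset, transforms, collate_fn) must be picklable and defined at module level, and the script needs a __main__ guard. A minimal sketch of the workaround; the dataset here is a placeholder, not the repo's ChestX-ray14 loader:

import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # Placeholder data; in the real script this is the ChestX-ray14 test set.
    dataset = TensorDataset(torch.randn(8, 3, 224, 224), torch.zeros(8, 14))

    # num_workers=0 avoids spawning worker processes entirely, which sidesteps
    # the "Can't pickle local object" error and surfaces the real exception, if any.
    test_loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

    for i, (inp, target) in enumerate(test_loader):
        print(i, inp.shape, target.shape)

if __name__ == '__main__':
    # Required on Windows whenever num_workers > 0; harmless otherwise.
    main()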

agil27 avatar Jul 15 '20 06:07 agil27

It's late to mention this here, but you can simply pass strict=False when loading the checkpoint, i.e. replace model.load_state_dict(checkpoint['state_dict']) with model.load_state_dict(checkpoint['state_dict'], strict=False).
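One caveat worth adding: with strict=False, any key whose name does not match is silently skipped, so if every key was renamed (the situation in this issue) the model keeps its random initialization. In recent PyTorch versions load_state_dict returns the lists of missing and unexpected keys, which makes this easy to check; a small sketch:

# If both lists are roughly as long as the state_dicts themselves, strict=False
# only hid the mismatch instead of fixing it.
result = model.load_state_dict(checkpoint['state_dict'], strict=False)
print('missing keys   :', len(result.missing_keys))
print('unexpected keys:', len(result.unexpected_keys))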

Aliktk avatar Jun 01 '21 05:06 Aliktk

This Worked for me:

from collections import OrderedDict

state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()

for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.' + k
    else:
        k = k.replace('module.densenet121.features', 'features')
        k = k.replace('module.densenet121.classifier', 'classifier')
        k = k.replace('.norm.1', '.norm1')
        k = k.replace('.conv.1', '.conv1')
        k = k.replace('.norm.2', '.norm2')
        k = k.replace('.conv.2', '.conv2')
    new_state_dict[k] = v

model.load_state_dict(new_state_dict)
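Note that after this remapping the keys look like features.* and classifier.0.*, which matches a bare torchvision densenet121 carrying CheXNet's Sequential classifier head rather than the DataParallel-wrapped wrapper class. A hedged sketch of a model that lines up with those keys (N_CLASSES = 14 follows the repo; verify the head layout against your own model.py):

import torch.nn as nn
import torchvision

N_CLASSES = 14  # the ChestX-ray14 pathology classes used by CheXNet

model = torchvision.models.densenet121(pretrained=False)
num_ftrs = model.classifier.in_features
# Sequential head so the remapped 'classifier.0.*' keys land on index 0.
model.classifier = nn.Sequential(nn.Linear(num_ftrs, N_CLASSES), nn.Sigmoid())

model.load_state_dict(new_state_dict)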

Ibrahmm avatar Oct 18 '21 06:10 Ibrahmm

It's late to mention this here, but you can simply pass strict=False when loading the checkpoint: model.load_state_dict(checkpoint['state_dict'], strict=False)

did it work?

taherpat avatar Jan 02 '24 15:01 taherpat

did it work?

Yes, in my case it worked.

Aliktk avatar Jan 15 '24 06:01 Aliktk