CheXNet
model.load_state_dict(checkpoint['state_dict']) error with pytorch 0.4.0
I was running the code without any problem on pytorch 0.3.0. I upgraded yesterday to pytorch 0.4.0 and now can't load the checkpoint file. I am on Ubuntu with Python 3.6 in a conda env. I get this error:
RuntimeError                              Traceback (most recent call last)
--> 182     main()

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                 self.__class__.__name__, "\n\t".join(error_msgs)))
    722
    723     def parameters(self):
RuntimeError: Error(s) in loading state_dict for DenseNet121: Missing key(s) in state_dict: "densenet121.features.conv0.weight", "densenet121.features.norm0.weight", "densenet121.features.norm0.bias", "densenet121.features.norm0.running_mean", "densenet121.features.norm0.running_var", "densenet121.features.denseblock1.denselayer1.norm1.weight", "densenet121.features.denseblock1.denselayer1.norm1.bias", "densenet121.features.denseblock1.denselayer1.norm1.running_mean", (entire network ...) "module.densenet121.features.denseblock4.denselayer16.conv.2.weight", "module.densenet121.features.norm5.weight", "module.densenet121.features.norm5.bias", "module.densenet121.features.norm5.running_mean", "module.densenet121.features.norm5.running_var", "module.densenet121.classifier.0.weight", "module.densenet121.classifier.0.bias".
It is likely related to this change in pytorch 0.4.0 (https://pytorch.org/2018/04/22/0_4_0-migration-guide.html):

New edge-case constraints on names of submodules, parameters, and buffers in nn.Module: a name that is an empty string or contains "." is no longer permitted in module.add_module(name, value), module.add_parameter(name, value) or module.add_buffer(name, value), because such names may cause lost data in the state_dict. If you are loading a checkpoint for modules containing such names, please update the module definition and patch the state_dict before loading it.
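To see why the migration guide forbids "." in names: state_dict keys are just dot-joined paths through the module tree, so a name containing "." becomes ambiguous once flattened. A stdlib-only sketch of the collision (names are illustrative, not the real module tree):

```python
# Why a "." inside a submodule name loses data: state_dict keys are flat
# strings joined with ".", so two different module trees can collide on
# the same key, and one entry silently overwrites the other.
def flat_key(*names):
    return '.'.join(names)

# A submodule literally named 'norm.1' (old torchvision DenseNet style)...
old_style = flat_key('denselayer1', 'norm.1', 'weight')
# ...produces the same flat key as a submodule 'norm' with a child '1'.
nested = flat_key('denselayer1', 'norm', '1', 'weight')
assert old_style == nested == 'denselayer1.norm.1.weight'
print(old_style)
```

This is exactly the CheXNet situation: the checkpoint was saved with the old 'norm.1'/'conv.2' layer names, while 0.4.0's DenseNet uses 'norm1'/'conv2', so the keys have to be patched before loading.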
I have the same problem too. @alexandrecc, do you have any solution so far?
You should be able to make something like this work:

import re
import torch

# Code modified from torchvision densenet source for loading from pre-0.4 densenet weights.
checkpoint = torch.load('./model.pth.tar')
state_dict = checkpoint['state_dict']

remove_data_parallel = False  # Change to True if you are not wrapping the model in nn.DataParallel(model)

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

for key in list(state_dict.keys()):
    match = pattern.match(key)
    new_key = match.group(1) + match.group(2) if match else key
    new_key = new_key[7:] if remove_data_parallel else new_key
    state_dict[new_key] = state_dict[key]
    # Delete old key only if modified.
    if match or remove_data_parallel:
        del state_dict[key]
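To see what the regex actually does, here is a quick check on two hypothetical checkpoint keys (key names invented for illustration, matching the pattern in the thread's error message):

```python
import re

# Same pattern as the snippet above, applied to two sample keys from a
# pre-0.4-style checkpoint to show exactly what gets renamed.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

# A dense-layer key with the extra dot is rewritten ('norm.1' -> 'norm1')...
m = pattern.match('module.densenet121.features.denseblock1.denselayer1.norm.1.weight')
print(m.group(1) + m.group(2))
# -> module.densenet121.features.denseblock1.denselayer1.norm1.weight

# ...while keys outside dense layers (conv0, norm5, classifier) don't match
# and are left untouched.
assert pattern.match('module.densenet121.features.conv0.weight') is None
```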
Thanks JasperJenkins... This worked but I received another error in the form:
Traceback (most recent call last):
File "C:/Users/Nasir Isa/Documents/1Research/algortihm/CheXNet-master/CheXNet-master/m1.py", line 142, in
Same Here
+1, seeing this error as well. Has torch implemented some way to ensure backwards compatibility when parsing older models? I can't seem to find anything, and I would rather not change the keys themselves since that seems quite error-prone.
Thanks JasperJenkins... This worked but I received another error in the form:

Traceback (most recent call last):
  File "C:/Users/Nasir Isa/Documents/1Research/algortihm/CheXNet-master/CheXNet-master/m1.py", line 142, in <module>
    main()
  File "C:/Users/Nasir Isa/Documents/1Research/algortihm/CheXNet-master/CheXNet-master/m1.py", line 83, in main
    for i, (inp, target) in enumerate(test_loader):
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 501, in __iter__
    return _DataLoaderIter(self)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 289, in __init__
    w.start()
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\process.py", line 105, in start
    self._popen = self._Popen(self)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\Nasir Isa\AppData\Local\Programs\Python\Python36\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'main.<locals>.<lambda>'

Please help me out!
Set num_workers=0 and try again to find the real issue.
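For context on that advice: with num_workers > 0 on Windows, DataLoader starts worker processes via spawn, which must pickle everything handed to them, and functions defined inside another function (including lambdas used as transforms or a collate_fn) cannot be pickled. A stdlib-only reproduction of the same error, no torch needed:

```python
import pickle

def main():
    # A function defined inside main() (a lambda here, as is common for
    # transforms / collate_fn) cannot be pickled, which is exactly what
    # spawn-based DataLoader workers need to do on Windows.
    local_fn = lambda batch: batch
    try:
        pickle.dumps(local_fn)
    except (AttributeError, pickle.PicklingError) as exc:
        print('pickle failed:', exc)  # same "Can't pickle local object" error

main()
```

Setting num_workers=0 sidesteps the worker processes entirely; the longer-term fix is to move such functions to module level so they become picklable.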
It's late to mention it here, but you can just load the checkpoint with strict=False, like this. Instead of:

model.load_state_dict(checkpoint['state_dict'])

use:

model.load_state_dict(checkpoint['state_dict'], strict=False)
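One caution with strict=False: it renames nothing, it simply skips every key that does not match, so if all the checkpoint keys still carry the old 'module.' prefix or 'norm.1' naming, nothing may actually be loaded and the model keeps its random initialization. What strict=True enforces is roughly a set comparison between the two key sets; a plain-Python sketch with invented keys:

```python
# Hypothetical model vs. checkpoint key sets differing only by the
# DataParallel 'module.' prefix.
model_keys = {'densenet121.features.conv0.weight',
              'densenet121.classifier.0.weight'}
ckpt_keys = {'module.' + k for k in model_keys}

missing = model_keys - ckpt_keys     # strict=True raises on these...
unexpected = ckpt_keys - model_keys  # ...and on these

# strict=False suppresses the error, but the 'missing' parameters are
# still never loaded -- here that is every parameter in the model.
print(sorted(missing))
```

So strict=False is only a real fix when the mismatched keys are ones you genuinely want to skip; otherwise patch the key names first, as in the snippets above.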
This worked for me:

from collections import OrderedDict

state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.' + k
    else:
        k = k.replace('module.densenet121.features', 'features')
        k = k.replace('module.densenet121.classifier', 'classifier')
        k = k.replace('.norm.1', '.norm1')
        k = k.replace('.conv.1', '.conv1')
        k = k.replace('.norm.2', '.norm2')
        k = k.replace('.conv.2', '.conv2')
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)
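Applied to a couple of hypothetical checkpoint keys, that remapping behaves like this (self-contained sketch, key names invented to match the thread's error message):

```python
from collections import OrderedDict

# Two sample keys: one with the DataParallel + old-style naming, one
# without any 'module.' prefix at all.
state_dict = OrderedDict([
    ('module.densenet121.features.denseblock1.denselayer1.norm.1.weight', 0.1),
    ('densenet121.features.conv0.weight', 0.2),
])

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.' + k                     # un-prefixed keys get the prefix
    else:
        # prefixed keys are stripped and renamed to the 0.4-style layout
        k = k.replace('module.densenet121.features', 'features')
        k = k.replace('module.densenet121.classifier', 'classifier')
        k = k.replace('.norm.1', '.norm1')
        k = k.replace('.conv.1', '.conv1')
        k = k.replace('.norm.2', '.norm2')
        k = k.replace('.conv.2', '.conv2')
    new_state_dict[k] = v

print(list(new_state_dict))
# -> ['features.denseblock1.denselayer1.norm1.weight',
#     'module.densenet121.features.conv0.weight']
```

Note the two branches pull in opposite directions (one adds 'module.', the other strips it), so which one you need depends on whether the model is wrapped in nn.DataParallel.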
It's late to mention it here, but you can just load the checkpoint with strict=False, like this. Instead of:
model.load_state_dict(checkpoint['state_dict'])
use:
model.load_state_dict(checkpoint['state_dict'], strict=False)

did it work?
Yes, in my case it worked.