EfficientNet-PyTorch
_fc.weight error when loading model for transfer learning
python 3.8.0
efficientnet-pytorch==0.7.0
I'm attempting to use EfficientNet for a transfer learning task.
I am loading the model like:
model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2)
I'm training on multiple GPUs and I don't freeze any of the layers. Everything goes as expected, with reported metrics showing improvement over the training epochs.
Then, I save the model like:
torch.save(model.state_dict(), MODEL_FILEPATH)
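For completeness, since training runs on multiple GPUs, a minimal sketch of what the save step looks like if the model is wrapped in nn.DataParallel (the wrapping and the unwrap step are assumptions for illustration; if the wrapper is left on, every key in state_dict() is prefixed with "module."):

import torch
import torch.nn as nn

# Unwrap nn.DataParallel before saving so the checkpoint keys match the plain
# EfficientNet module (e.g. '_fc.weight' rather than 'module._fc.weight').
to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(to_save.state_dict(), MODEL_FILEPATH)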
Although I trained on multiple GPUs, I am attempting to run inference on CPU, so I had to make the following modification in efficientnet_pytorch.utils.load_pretrained_weights():
if isinstance(weights_path, str):
    if torch.cuda.is_available():
        state_dict = torch.load(weights_path)
    else:
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
In my run script, I attempt to load the model with a custom .pth file like:
model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2, weights_path="path/to/custom.pth")
However, this finally gives me the error:
File "/path/to/modified/utils.py", line 606, in load_pretrained_weights
state_dict.pop('_fc.weight')
KeyError: '_fc.weight'
In the unmodified file, it's line 603. For reference, here is the original source function:
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(
            ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))
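For context, passing num_classes=2 to from_pretrained means this function is called with load_fc=False. If I'm reading model.py correctly, the call is roughly as follows (paraphrased, not a verbatim quote):

# Inside EfficientNet.from_pretrained (paraphrased):
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
load_pretrained_weights(model, model_name, weights_path=weights_path,
                        load_fc=(num_classes == 1000), advprop=advprop)

So with num_classes=2 and a custom weights_path, the load_fc=False branch is the one that runs.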
The error indicates that _fc.weight is missing from the dict, but I'm not exactly sure why that would be, given that I saved the model using .state_dict(). It occurred to me that since the state_dict actually expects 2 classes (not 1000), I don't want the function to take the load_fc=False branch as implemented, because it removes the final layers. So I simply commented out that portion of the code so that the following would run:
ret = model.load_state_dict(state_dict, strict=False)
assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
But this gave me the error:
AssertionError: Missing keys when loading pretrained weights: ['_conv_stem.weight', '_bn0.weight', '_bn0.bias', '_bn0.running_mean', '_bn0.running_var', '_blocks.0._depthwise_conv.weight', '_blocks.0._bn1.weight', '_blocks.0._bn1.bias', '_blocks.0._bn1.running_mean', '_blocks.0._bn1.running_var', '_blocks.0._se_reduce.weight', '_blocks.0._se_reduce.bias', '_blocks.0._se_expand.weight', '_blocks.0._se_expand.bias', '_blocks.0._project_conv.weight', '_blocks.0._bn2.weight', '_blocks.0._bn2.bias', '_blocks.0._bn2.running_mean', '_blocks.0._bn2.running_var', '_blocks.1._expand_conv.weight', '_blocks.1._bn0.weight', '_blocks.1._bn0.bias', '_blocks.1._bn0.running_mean', '_blocks.1._bn0.running_var', '_blocks.1._depthwise_conv.weight', '_blocks.1._bn1.weight', '_blocks.1._bn1.bias', '_blocks.1._bn1.running_mean', '_blocks.1._bn1.running_var', '_blocks.1._se_reduce.weight', '_blocks.1._se_reduce.bias', '_blocks.1._se_expand.weight', '_blocks.1._se_expand.bias', '_blocks.1._project_conv.weight', '_blocks.1._bn2.weight', '_blocks.1._bn2.bias', '_blocks.1._bn2.running_mean', '_blocks.1._bn2.running_var', '_blocks.2._expand_conv.weight', '_blocks.2._bn0.weight', '_blocks.2._bn0.bias', '_blocks.2._bn0.running_mean', '_blocks.2._bn0.running_var', '_blocks.2._depthwise_conv.weight', '_blocks.2._bn1.weight', '_blocks.2._bn1.bias', '_blocks.2._bn1.running_mean', '_blocks.2._bn1.running_var', '_blocks.2._se_reduce.weight', '_blocks.2._se_reduce.bias', '_blocks.2._se_expand.weight', '_blocks.2._se_expand.bias', '_blocks.2._project_conv.weight', '_blocks.2._bn2.weight', '_blocks.2._bn2.bias', '_blocks.2._bn2.running_mean', '_blocks.2._bn2.running_var', '_blocks.3._expand_conv.weight', '_blocks.3._bn0.weight', '_blocks.3._bn0.bias', '_blocks.3._bn0.running_mean', '_blocks.3._bn0.running_var', '_blocks.3._depthwise_conv.weight', '_blocks.3._bn1.weight', '_blocks.3._bn1.bias', '_blocks.3._bn1.running_mean', '_blocks.3._bn1.running_var', '_blocks.3._se_reduce.weight', '_blocks.3._se_reduce.bias', '_blocks.3._se_expand.weight', '_blocks.3._se_expand.bias', '_blocks.3._project_conv.weight', '_blocks.3._bn2.weight', '_blocks.3._bn2.bias', '_blocks.3._bn2.running_mean', '_blocks.3._bn2.running_var', '_blocks.4._expand_conv.weight', '_blocks.4._bn0.weight', '_blocks.4._bn0.bias', '_blocks.4._bn0.running_mean', '_blocks.4._bn0.running_var', '_blocks.4._depthwise_conv.weight', '_blocks.4._bn1.weight', '_blocks.4._bn1.bias', '_blocks.4._bn1.running_mean', '_blocks.4._bn1.running_var', '_blocks.4._se_reduce.weight', '_blocks.4._se_reduce.bias', '_blocks.4._se_expand.weight', '_blocks.4._se_expand.bias', '_blocks.4._project_conv.weight', '_blocks.4._bn2.weight', '_blocks.4._bn2.bias', '_blocks.4._bn2.running_mean', '_blocks.4._bn2.running_var', '_blocks.5._expand_conv.weight', '_blocks.5._bn0.weight', '_blocks.5._bn0.bias', '_blocks.5._bn0.running_mean', '_blocks.5._bn0.running_var', '_blocks.5._depthwise_conv.weight', '_blocks.5._bn1.weight', '_blocks.5._bn1.bias', '_blocks.5._bn1.running_mean', '_blocks.5._bn1.running_var', '_blocks.5._se_reduce.weight', '_blocks.5._se_reduce.bias', '_blocks.5._se_expand.weight', '_blocks.5._se_expand.bias', '_blocks.5._project_conv.weight', '_blocks.5._bn2.weight', '_blocks.5._bn2.bias', '_blocks.5._bn2.running_mean', '_blocks.5._bn2.running_var', '_blocks.6._expand_conv.weight', '_blocks.6._bn0.weight', '_blocks.6._bn0.bias', '_blocks.6._bn0.running_mean', '_blocks.6._bn0.running_var', '_blocks.6._depthwise_conv.weight', '_blocks.6._bn1.weight', '_blocks.6._bn1.bias', 
'_blocks.6._bn1.running_mean', '_blocks.6._bn1.running_var', '_blocks.6._se_reduce.weight', '_blocks.6._se_reduce.bias', '_blocks.6._se_expand.weight', '_blocks.6._se_expand.bias', '_blocks.6._project_conv.weight', '_blocks.6._bn2.weight', '_blocks.6._bn2.bias', '_blocks.6._bn2.running_mean', '_blocks.6._bn2.running_var', '_blocks.7._expand_conv.weight', '_blocks.7._bn0.weight', '_blocks.7._bn0.bias', '_blocks.7._bn0.running_mean', '_blocks.7._bn0.running_var', '_blocks.7._depthwise_conv.weight', '_blocks.7._bn1.weight', '_blocks.7._bn1.bias', '_blocks.7._bn1.running_mean', '_blocks.7._bn1.running_var', '_blocks.7._se_reduce.weight', '_blocks.7._se_reduce.bias', '_blocks.7._se_expand.weight', '_blocks.7._se_expand.bias', '_blocks.7._project_conv.weight', '_blocks.7._bn2.weight', '_blocks.7._bn2.bias', '_blocks.7._bn2.running_mean', '_blocks.7._bn2.running_var', '_blocks.8._expand_conv.weight', '_blocks.8._bn0.weight', '_blocks.8._bn0.bias', '_blocks.8._bn0.running_mean', '_blocks.8._bn0.running_var', '_blocks.8._depthwise_conv.weight', '_blocks.8._bn1.weight', '_blocks.8._bn1.bias', '_blocks.8._bn1.running_mean', '_blocks.8._bn1.running_var', '_blocks.8._se_reduce.weight', '_blocks.8._se_reduce.bias', '_blocks.8._se_expand.weight', '_blocks.8._se_expand.bias', '_blocks.8._project_conv.weight', '_blocks.8._bn2.weight', '_blocks.8._bn2.bias', '_blocks.8._bn2.running_mean', '_blocks.8._bn2.running_var', '_blocks.9._expand_conv.weight', '_blocks.9._bn0.weight', '_blocks.9._bn0.bias', '_blocks.9._bn0.running_mean', '_blocks.9._bn0.running_var', '_blocks.9._depthwise_conv.weight', '_blocks.9._bn1.weight', '_blocks.9._bn1.bias', '_blocks.9._bn1.running_mean', '_blocks.9._bn1.running_var', '_blocks.9._se_reduce.weight', '_blocks.9._se_reduce.bias', '_blocks.9._se_expand.weight', '_blocks.9._se_expand.bias', '_blocks.9._project_conv.weight', '_blocks.9._bn2.weight', '_blocks.9._bn2.bias', '_blocks.9._bn2.running_mean', '_blocks.9._bn2.running_var', '_blocks.10._expand_conv.weight', '_blocks.10._bn0.weight', '_blocks.10._bn0.bias', '_blocks.10._bn0.running_mean', '_blocks.10._bn0.running_var', '_blocks.10._depthwise_conv.weight', '_blocks.10._bn1.weight', '_blocks.10._bn1.bias', '_blocks.10._bn1.running_mean', '_blocks.10._bn1.running_var', '_blocks.10._se_reduce.weight', '_blocks.10._se_reduce.bias', '_blocks.10._se_expand.weight', '_blocks.10._se_expand.bias', '_blocks.10._project_conv.weight', '_blocks.10._bn2.weight', '_blocks.10._bn2.bias', '_blocks.10._bn2.running_mean', '_blocks.10._bn2.running_var', '_blocks.11._expand_conv.weight', '_blocks.11._bn0.weight', '_blocks.11._bn0.bias', '_blocks.11._bn0.running_mean', '_blocks.11._bn0.running_var', '_blocks.11._depthwise_conv.weight', '_blocks.11._bn1.weight', '_blocks.11._bn1.bias', '_blocks.11._bn1.running_mean', '_blocks.11._bn1.running_var', '_blocks.11._se_reduce.weight', '_blocks.11._se_reduce.bias', '_blocks.11._se_expand.weight', '_blocks.11._se_expand.bias', '_blocks.11._project_conv.weight', '_blocks.11._bn2.weight', '_blocks.11._bn2.bias', '_blocks.11._bn2.running_mean', '_blocks.11._bn2.running_var', '_blocks.12._expand_conv.weight', '_blocks.12._bn0.weight', '_blocks.12._bn0.bias', '_blocks.12._bn0.running_mean', '_blocks.12._bn0.running_var', '_blocks.12._depthwise_conv.weight', '_blocks.12._bn1.weight', '_blocks.12._bn1.bias', '_blocks.12._bn1.running_mean', '_blocks.12._bn1.running_var', '_blocks.12._se_reduce.weight', '_blocks.12._se_reduce.bias', '_blocks.12._se_expand.weight', '_blocks.12._se_expand.bias', 
'_blocks.12._project_conv.weight', '_blocks.12._bn2.weight', '_blocks.12._bn2.bias', '_blocks.12._bn2.running_mean', '_blocks.12._bn2.running_var', '_blocks.13._expand_conv.weight', '_blocks.13._bn0.weight', '_blocks.13._bn0.bias', '_blocks.13._bn0.running_mean', '_blocks.13._bn0.running_var', '_blocks.13._depthwise_conv.weight', '_blocks.13._bn1.weight', '_blocks.13._bn1.bias', '_blocks.13._bn1.running_mean', '_blocks.13._bn1.running_var', '_blocks.13._se_reduce.weight', '_blocks.13._se_reduce.bias', '_blocks.13._se_expand.weight', '_blocks.13._se_expand.bias', '_blocks.13._project_conv.weight', '_blocks.13._bn2.weight', '_blocks.13._bn2.bias', '_blocks.13._bn2.running_mean', '_blocks.13._bn2.running_var', '_blocks.14._expand_conv.weight', '_blocks.14._bn0.weight', '_blocks.14._bn0.bias', '_blocks.14._bn0.running_mean', '_blocks.14._bn0.running_var', '_blocks.14._depthwise_conv.weight', '_blocks.14._bn1.weight', '_blocks.14._bn1.bias', '_blocks.14._bn1.running_mean', '_blocks.14._bn1.running_var', '_blocks.14._se_reduce.weight', '_blocks.14._se_reduce.bias', '_blocks.14._se_expand.weight', '_blocks.14._se_expand.bias', '_blocks.14._project_conv.weight', '_blocks.14._bn2.weight', '_blocks.14._bn2.bias', '_blocks.14._bn2.running_mean', '_blocks.14._bn2.running_var', '_blocks.15._expand_conv.weight', '_blocks.15._bn0.weight', '_blocks.15._bn0.bias', '_blocks.15._bn0.running_mean', '_blocks.15._bn0.running_var', '_blocks.15._depthwise_conv.weight', '_blocks.15._bn1.weight', '_blocks.15._bn1.bias', '_blocks.15._bn1.running_mean', '_blocks.15._bn1.running_var', '_blocks.15._se_reduce.weight', '_blocks.15._se_reduce.bias', '_blocks.15._se_expand.weight', '_blocks.15._se_expand.bias', '_blocks.15._project_conv.weight', '_blocks.15._bn2.weight', '_blocks.15._bn2.bias', '_blocks.15._bn2.running_mean', '_blocks.15._bn2.running_var', '_conv_head.weight', '_bn1.weight', '_bn1.bias', '_bn1.running_mean', '_bn1.running_var', '_fc.weight', '_fc.bias']
Based on the recommendation to pass num_classes, and from what I've read in the docs and the issues, I wouldn't expect to have to make any other modifications to the network to load a fine-tuned model. If I do, that's fine; I'd just be happy to figure out what the solution is.
Thanks!
Also, I believe this issue is related, but it was a bit vague, so I created a new one.
Thanks for the very comprehensive issue. My apologies for the late response. I'm having trouble reproducing the issue. I made a Google Colab with a minimal example:
https://colab.research.google.com/drive/1JLisFMRI8rxm5TUvbcUmZBmAKUcohyvF?usp=sharing
That said, you can always try loading with model.load_state_dict, as in:
state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2)
model.load_state_dict(state_dict)
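One more thing worth checking, though it's only a guess on my part: if the checkpoint was saved while the model was still wrapped in nn.DataParallel (common with multi-GPU training), every key carries a "module." prefix, which would explain the KeyError on '_fc.weight'. A sketch of stripping that prefix before loading:

import torch
from efficientnet_pytorch import EfficientNet

state_dict = torch.load("path/to/custom.pth", map_location=torch.device("cpu"))
# Drop a possible 'module.' prefix left by nn.DataParallel so the keys match
# the plain model (e.g. 'module._fc.weight' -> '_fc.weight').
state_dict = {k[len("module."):] if k.startswith("module.") else k: v
              for k, v in state_dict.items()}

model = EfficientNet.from_name("efficientnet-b0", num_classes=2)
model.load_state_dict(state_dict)
model.eval()

from_name is used here so the ImageNet weights are not downloaded only to be overwritten.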
I have the same problem as you. I want to extract 1024-dimensional fc features, similar to InceptionV3, so I changed EfficientNet-B4's fc layer from 1792 to 1024. Training works fine, but when I evaluate the model I hit the same error as you. Please help; hope you have fun.
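In case it helps, this is roughly how I swap the head (a sketch; it assumes the classifier is the _fc attribute, a plain nn.Linear, as in this repo):

import torch.nn as nn
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained("efficientnet-b4")
# Replace the default 1792-dimensional head of B4 with a 1024-dimensional one.
model._fc = nn.Linear(model._fc.in_features, 1024)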