EfficientNet-PyTorch

_fc.weight error when loading model for transfer learning

Open thommiano opened this issue 4 years ago • 2 comments

Python 3.8.0, efficientnet-pytorch==0.7.0

I'm attempting to use EfficientNet for a transfer learning task.

I am loading the model like:

model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2)
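
For context, from_pretrained with num_classes=2 builds the network with a freshly initialized 2-way classifier head and loads ImageNet weights for the rest of the backbone. A quick sanity check (assuming efficientnet-pytorch 0.7.0, where the head is the _fc attribute):

print(model._fc)  # expected: Linear(in_features=1280, out_features=2, bias=True)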

I'm training on multiple GPUs and I don't freeze any of the layers. Everything goes as expected, with reported metrics showing improvement over the training epochs.

Then, I save the model like:

torch.save(model.state_dict(), MODEL_FILEPATH)
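
If the multi-GPU training used nn.DataParallel (an assumption, since the wrapper isn't shown above), model.state_dict() stores every key with a module. prefix, which later confuses loaders that expect unprefixed keys. A minimal sketch that saves an unprefixed checkpoint under that assumption:

import torch

# Assumption: `model` may be wrapped in nn.DataParallel for multi-GPU training.
# Saving the underlying module keeps the checkpoint keys free of the
# "module." prefix that DataParallel adds to every parameter name.
raw_model = model.module if isinstance(model, torch.nn.DataParallel) else model
torch.save(raw_model.state_dict(), MODEL_FILEPATH)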

I trained on multiple GPUs, but I'm attempting to run inference on CPU, so I had to make the following modification in efficientnet_pytorch.utils.load_pretrained_weights():

    if isinstance(weights_path, str):
        if torch.cuda.is_available():
            state_dict = torch.load(weights_path)
        else:
            state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
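
The branch isn't strictly necessary: map_location="cpu" is safe even when CUDA is available, since the loaded tensors can be moved to the GPU afterwards. A simpler equivalent of the modification above:

    if isinstance(weights_path, str):
        # "cpu" works on both CPU-only and CUDA machines; move the model
        # to the GPU afterwards with model.to("cuda") if needed.
        state_dict = torch.load(weights_path, map_location="cpu")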

In my run script, I attempt to load the model with a custom .pth file like:

model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2, weights_path="path/to/custom.pth")
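
Worth noting: in efficientnet-pytorch 0.7.0, from_pretrained forwards load_fc=(num_classes == 1000) to load_pretrained_weights (worth verifying against your installed copy of model.py), roughly:

load_pretrained_weights(model, model_name, weights_path=weights_path,
                        load_fc=(num_classes == 1000), advprop=advprop)

So with num_classes=2, the load_fc=False branch runs and unconditionally calls state_dict.pop('_fc.weight'), which raises the KeyError below whenever the checkpoint doesn't contain a key under exactly that name.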

However, this now fails with the error:

  File "/path/to/modified/utils.py", line 606, in load_pretrained_weights
    state_dict.pop('_fc.weight')
KeyError: '_fc.weight'

In the unmodified file, it's line 603. For reference, here is the original source function:

def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(
            ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))

The error indicates that _fc.weight is missing from the dict, but I'm not sure why it would be, given that I saved the model using .state_dict(). It occurred to me that, since the state_dict is actually expecting 2 classes (not 1000), I don't want the function to take the load_fc branch as implemented, because it would remove the last layers. So I simply commented out that portion of the code so that only the following would run:

    ret = model.load_state_dict(state_dict, strict=False)
    assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)

But this gave me the error:

AssertionError: Missing keys when loading pretrained weights: ['_conv_stem.weight', '_bn0.weight', '_bn0.bias', '_bn0.running_mean', '_bn0.running_var', '_blocks.0._depthwise_conv.weight', '_blocks.0._bn1.weight', '_blocks.0._bn1.bias', '_blocks.0._bn1.running_mean', '_blocks.0._bn1.running_var', '_blocks.0._se_reduce.weight', '_blocks.0._se_reduce.bias', '_blocks.0._se_expand.weight', '_blocks.0._se_expand.bias', '_blocks.0._project_conv.weight', '_blocks.0._bn2.weight', '_blocks.0._bn2.bias', '_blocks.0._bn2.running_mean', '_blocks.0._bn2.running_var', '_blocks.1._expand_conv.weight', '_blocks.1._bn0.weight', '_blocks.1._bn0.bias', '_blocks.1._bn0.running_mean', '_blocks.1._bn0.running_var', '_blocks.1._depthwise_conv.weight', '_blocks.1._bn1.weight', '_blocks.1._bn1.bias', '_blocks.1._bn1.running_mean', '_blocks.1._bn1.running_var', '_blocks.1._se_reduce.weight', '_blocks.1._se_reduce.bias', '_blocks.1._se_expand.weight', '_blocks.1._se_expand.bias', '_blocks.1._project_conv.weight', '_blocks.1._bn2.weight', '_blocks.1._bn2.bias', '_blocks.1._bn2.running_mean', '_blocks.1._bn2.running_var', '_blocks.2._expand_conv.weight', '_blocks.2._bn0.weight', '_blocks.2._bn0.bias', '_blocks.2._bn0.running_mean', '_blocks.2._bn0.running_var', '_blocks.2._depthwise_conv.weight', '_blocks.2._bn1.weight', '_blocks.2._bn1.bias', '_blocks.2._bn1.running_mean', '_blocks.2._bn1.running_var', '_blocks.2._se_reduce.weight', '_blocks.2._se_reduce.bias', '_blocks.2._se_expand.weight', '_blocks.2._se_expand.bias', '_blocks.2._project_conv.weight', '_blocks.2._bn2.weight', '_blocks.2._bn2.bias', '_blocks.2._bn2.running_mean', '_blocks.2._bn2.running_var', '_blocks.3._expand_conv.weight', '_blocks.3._bn0.weight', '_blocks.3._bn0.bias', '_blocks.3._bn0.running_mean', '_blocks.3._bn0.running_var', '_blocks.3._depthwise_conv.weight', '_blocks.3._bn1.weight', '_blocks.3._bn1.bias', '_blocks.3._bn1.running_mean', '_blocks.3._bn1.running_var', '_blocks.3._se_reduce.weight', '_blocks.3._se_reduce.bias', '_blocks.3._se_expand.weight', '_blocks.3._se_expand.bias', '_blocks.3._project_conv.weight', '_blocks.3._bn2.weight', '_blocks.3._bn2.bias', '_blocks.3._bn2.running_mean', '_blocks.3._bn2.running_var', '_blocks.4._expand_conv.weight', '_blocks.4._bn0.weight', '_blocks.4._bn0.bias', '_blocks.4._bn0.running_mean', '_blocks.4._bn0.running_var', '_blocks.4._depthwise_conv.weight', '_blocks.4._bn1.weight', '_blocks.4._bn1.bias', '_blocks.4._bn1.running_mean', '_blocks.4._bn1.running_var', '_blocks.4._se_reduce.weight', '_blocks.4._se_reduce.bias', '_blocks.4._se_expand.weight', '_blocks.4._se_expand.bias', '_blocks.4._project_conv.weight', '_blocks.4._bn2.weight', '_blocks.4._bn2.bias', '_blocks.4._bn2.running_mean', '_blocks.4._bn2.running_var', '_blocks.5._expand_conv.weight', '_blocks.5._bn0.weight', '_blocks.5._bn0.bias', '_blocks.5._bn0.running_mean', '_blocks.5._bn0.running_var', '_blocks.5._depthwise_conv.weight', '_blocks.5._bn1.weight', '_blocks.5._bn1.bias', '_blocks.5._bn1.running_mean', '_blocks.5._bn1.running_var', '_blocks.5._se_reduce.weight', '_blocks.5._se_reduce.bias', '_blocks.5._se_expand.weight', '_blocks.5._se_expand.bias', '_blocks.5._project_conv.weight', '_blocks.5._bn2.weight', '_blocks.5._bn2.bias', '_blocks.5._bn2.running_mean', '_blocks.5._bn2.running_var', '_blocks.6._expand_conv.weight', '_blocks.6._bn0.weight', '_blocks.6._bn0.bias', '_blocks.6._bn0.running_mean', '_blocks.6._bn0.running_var', '_blocks.6._depthwise_conv.weight', '_blocks.6._bn1.weight', '_blocks.6._bn1.bias', 
'_blocks.6._bn1.running_mean', '_blocks.6._bn1.running_var', '_blocks.6._se_reduce.weight', '_blocks.6._se_reduce.bias', '_blocks.6._se_expand.weight', '_blocks.6._se_expand.bias', '_blocks.6._project_conv.weight', '_blocks.6._bn2.weight', '_blocks.6._bn2.bias', '_blocks.6._bn2.running_mean', '_blocks.6._bn2.running_var', '_blocks.7._expand_conv.weight', '_blocks.7._bn0.weight', '_blocks.7._bn0.bias', '_blocks.7._bn0.running_mean', '_blocks.7._bn0.running_var', '_blocks.7._depthwise_conv.weight', '_blocks.7._bn1.weight', '_blocks.7._bn1.bias', '_blocks.7._bn1.running_mean', '_blocks.7._bn1.running_var', '_blocks.7._se_reduce.weight', '_blocks.7._se_reduce.bias', '_blocks.7._se_expand.weight', '_blocks.7._se_expand.bias', '_blocks.7._project_conv.weight', '_blocks.7._bn2.weight', '_blocks.7._bn2.bias', '_blocks.7._bn2.running_mean', '_blocks.7._bn2.running_var', '_blocks.8._expand_conv.weight', '_blocks.8._bn0.weight', '_blocks.8._bn0.bias', '_blocks.8._bn0.running_mean', '_blocks.8._bn0.running_var', '_blocks.8._depthwise_conv.weight', '_blocks.8._bn1.weight', '_blocks.8._bn1.bias', '_blocks.8._bn1.running_mean', '_blocks.8._bn1.running_var', '_blocks.8._se_reduce.weight', '_blocks.8._se_reduce.bias', '_blocks.8._se_expand.weight', '_blocks.8._se_expand.bias', '_blocks.8._project_conv.weight', '_blocks.8._bn2.weight', '_blocks.8._bn2.bias', '_blocks.8._bn2.running_mean', '_blocks.8._bn2.running_var', '_blocks.9._expand_conv.weight', '_blocks.9._bn0.weight', '_blocks.9._bn0.bias', '_blocks.9._bn0.running_mean', '_blocks.9._bn0.running_var', '_blocks.9._depthwise_conv.weight', '_blocks.9._bn1.weight', '_blocks.9._bn1.bias', '_blocks.9._bn1.running_mean', '_blocks.9._bn1.running_var', '_blocks.9._se_reduce.weight', '_blocks.9._se_reduce.bias', '_blocks.9._se_expand.weight', '_blocks.9._se_expand.bias', '_blocks.9._project_conv.weight', '_blocks.9._bn2.weight', '_blocks.9._bn2.bias', '_blocks.9._bn2.running_mean', '_blocks.9._bn2.running_var', '_blocks.10._expand_conv.weight', '_blocks.10._bn0.weight', '_blocks.10._bn0.bias', '_blocks.10._bn0.running_mean', '_blocks.10._bn0.running_var', '_blocks.10._depthwise_conv.weight', '_blocks.10._bn1.weight', '_blocks.10._bn1.bias', '_blocks.10._bn1.running_mean', '_blocks.10._bn1.running_var', '_blocks.10._se_reduce.weight', '_blocks.10._se_reduce.bias', '_blocks.10._se_expand.weight', '_blocks.10._se_expand.bias', '_blocks.10._project_conv.weight', '_blocks.10._bn2.weight', '_blocks.10._bn2.bias', '_blocks.10._bn2.running_mean', '_blocks.10._bn2.running_var', '_blocks.11._expand_conv.weight', '_blocks.11._bn0.weight', '_blocks.11._bn0.bias', '_blocks.11._bn0.running_mean', '_blocks.11._bn0.running_var', '_blocks.11._depthwise_conv.weight', '_blocks.11._bn1.weight', '_blocks.11._bn1.bias', '_blocks.11._bn1.running_mean', '_blocks.11._bn1.running_var', '_blocks.11._se_reduce.weight', '_blocks.11._se_reduce.bias', '_blocks.11._se_expand.weight', '_blocks.11._se_expand.bias', '_blocks.11._project_conv.weight', '_blocks.11._bn2.weight', '_blocks.11._bn2.bias', '_blocks.11._bn2.running_mean', '_blocks.11._bn2.running_var', '_blocks.12._expand_conv.weight', '_blocks.12._bn0.weight', '_blocks.12._bn0.bias', '_blocks.12._bn0.running_mean', '_blocks.12._bn0.running_var', '_blocks.12._depthwise_conv.weight', '_blocks.12._bn1.weight', '_blocks.12._bn1.bias', '_blocks.12._bn1.running_mean', '_blocks.12._bn1.running_var', '_blocks.12._se_reduce.weight', '_blocks.12._se_reduce.bias', '_blocks.12._se_expand.weight', '_blocks.12._se_expand.bias', 
'_blocks.12._project_conv.weight', '_blocks.12._bn2.weight', '_blocks.12._bn2.bias', '_blocks.12._bn2.running_mean', '_blocks.12._bn2.running_var', '_blocks.13._expand_conv.weight', '_blocks.13._bn0.weight', '_blocks.13._bn0.bias', '_blocks.13._bn0.running_mean', '_blocks.13._bn0.running_var', '_blocks.13._depthwise_conv.weight', '_blocks.13._bn1.weight', '_blocks.13._bn1.bias', '_blocks.13._bn1.running_mean', '_blocks.13._bn1.running_var', '_blocks.13._se_reduce.weight', '_blocks.13._se_reduce.bias', '_blocks.13._se_expand.weight', '_blocks.13._se_expand.bias', '_blocks.13._project_conv.weight', '_blocks.13._bn2.weight', '_blocks.13._bn2.bias', '_blocks.13._bn2.running_mean', '_blocks.13._bn2.running_var', '_blocks.14._expand_conv.weight', '_blocks.14._bn0.weight', '_blocks.14._bn0.bias', '_blocks.14._bn0.running_mean', '_blocks.14._bn0.running_var', '_blocks.14._depthwise_conv.weight', '_blocks.14._bn1.weight', '_blocks.14._bn1.bias', '_blocks.14._bn1.running_mean', '_blocks.14._bn1.running_var', '_blocks.14._se_reduce.weight', '_blocks.14._se_reduce.bias', '_blocks.14._se_expand.weight', '_blocks.14._se_expand.bias', '_blocks.14._project_conv.weight', '_blocks.14._bn2.weight', '_blocks.14._bn2.bias', '_blocks.14._bn2.running_mean', '_blocks.14._bn2.running_var', '_blocks.15._expand_conv.weight', '_blocks.15._bn0.weight', '_blocks.15._bn0.bias', '_blocks.15._bn0.running_mean', '_blocks.15._bn0.running_var', '_blocks.15._depthwise_conv.weight', '_blocks.15._bn1.weight', '_blocks.15._bn1.bias', '_blocks.15._bn1.running_mean', '_blocks.15._bn1.running_var', '_blocks.15._se_reduce.weight', '_blocks.15._se_reduce.bias', '_blocks.15._se_expand.weight', '_blocks.15._se_expand.bias', '_blocks.15._project_conv.weight', '_blocks.15._bn2.weight', '_blocks.15._bn2.bias', '_blocks.15._bn2.running_mean', '_blocks.15._bn2.running_var', '_conv_head.weight', '_bn1.weight', '_bn1.bias', '_bn1.running_mean', '_bn1.running_var', '_fc.weight', '_fc.bias']

Based on the recommendation to pass num_classes, and from what I've read in the docs and the issues, I wouldn't expect to have to make any other modifications to the network to load a fine-tuned model. If I do, that's okay; I'd just be happy to figure out what the solution is.

Thanks!

Also, I believe this issue is related, but it was a bit vague, so I created a new one.

thommiano avatar Mar 31 '21 00:03 thommiano

Thanks for the very comprehensive issue. My apologies for the late response. I'm having trouble reproducing the issue. I made a Google Colab with a minimal example:

https://colab.research.google.com/drive/1JLisFMRI8rxm5TUvbcUmZBmAKUcohyvF?usp=sharing

That said, you can always try loading with model.load_state_dict, as in:

state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2)
model.load_state_dict(state_dict)
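
If the checkpoint was saved from a model wrapped in nn.DataParallel (an assumption, but consistent with every key being reported missing above, since the saved keys would all carry a module. prefix), stripping the prefix before loading should work. A minimal sketch under that assumption, compatible with Python 3.8:

state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
# Assumption: the checkpoint came from an nn.DataParallel-wrapped model,
# so every key looks like "module._conv_stem.weight". Strip the prefix.
state_dict = {k[len("module."):] if k.startswith("module.") else k: v
              for k, v in state_dict.items()}
model = EfficientNet.from_pretrained("efficientnet-b0", num_classes=2)
model.load_state_dict(state_dict)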

lukemelas avatar Apr 15 '21 13:04 lukemelas

I have the same problem as you. I want to extract 1024-dimensional fc features, like InceptionV3, so I changed efficientnet-b4's fc layer from 1792 to 1024 features. It works fine during training, but when I evaluate, the same problem occurs. Please help me; hope you have fun.

buaacarzp avatar Jun 01 '21 07:06 buaacarzp