
RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict:


When I load the pretrained AGGAN-Mod checkpoint, I get this error:

```
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "blocks.0.0.bn1.gain.weight", "blocks.0.0.bn1.bias.weight", "blocks.0.0.bn2.gain.weight", "blocks.0.0.bn2.bias.weight", "blocks.1.0.bn1.gain.weight", "blocks.1.0.bn1.bias.weight", "blocks.1.0.bn2.gain.weight", "blocks.1.0.bn2.bias.weight", "blocks.2.0.bn1.gain.weight", "blocks.2.0.bn1.bias.weight", "blocks.2.0.bn2.gain.weight", "blocks.2.0.bn2.bias.weight".
	Unexpected key(s) in state_dict: "blocks.0.0.bn1.embed0.weight", "blocks.0.0.bn1.embed1.weight", "blocks.0.0.bn2.embed0.weight", "blocks.0.0.bn2.embed1.weight", "blocks.1.0.bn1.embed0.weight", "blocks.1.0.bn1.embed1.weight", "blocks.1.0.bn2.embed0.weight", "blocks.1.0.bn2.embed1.weight", "blocks.2.0.bn1.embed0.weight", "blocks.2.0.bn1.embed1.weight", "blocks.2.0.bn2.embed0.weight", "blocks.2.0.bn2.embed1.weight".
```

So I went to find out why. The generator built by the current code has this structure:

```
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
```

And the generator in the author's training log looks like this:

```
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
```
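To make the difference concrete, here is a minimal sketch of the two variants, reconstructed from the printed structures above (the forward passes are my reading of BigGAN-style conditional BN, not the exact StudioGAN source):

```python
import torch.nn as nn

# Old variant: per-class gain/bias stored in Embedding tables, so the
# saved parameters are named "embed0.weight" / "embed1.weight".
class CondBN2dEmbed(nn.Module):
    def __init__(self, num_classes, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1,
                                 affine=False, track_running_stats=True)
        self.embed0 = nn.Embedding(num_classes, num_features)  # gain table
        self.embed1 = nn.Embedding(num_classes, num_features)  # bias table

    def forward(self, x, y):  # y: integer labels, shape (N,)
        gain = (1 + self.embed0(y)).view(-1, x.size(1), 1, 1)
        bias = self.embed1(y).view(-1, x.size(1), 1, 1)
        return gain * self.bn(x) + bias

# New variant: gain/bias are Linear layers, so the parameters are named
# "gain.weight" / "bias.weight" -- hence the missing/unexpected keys.
class CondBN2dLinear(nn.Module):
    def __init__(self, num_classes, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1,
                                 affine=False, track_running_stats=True)
        self.gain = nn.Linear(num_classes, num_features, bias=False)
        self.bias = nn.Linear(num_classes, num_features, bias=False)

    def forward(self, x, y_onehot):  # y_onehot: one-hot labels, (N, num_classes)
        gain = (1 + self.gain(y_onehot)).view(-1, x.size(1), 1, 1)
        bias = self.bias(y_onehot).view(-1, x.size(1), 1, 1)
        return gain * self.bn(x) + bias

print(list(CondBN2dEmbed(10, 256).state_dict().keys())[-2:])   # ['embed0.weight', 'embed1.weight']
print(list(CondBN2dLinear(10, 256).state_dict().keys())[-2:])  # ['gain.weight', 'bias.weight']
```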

I then looked at ConditionalBatchNorm2d in the latest code (in ops.py) and found:

```python
self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
```

but g_linear = ops.linear (in config.py), so the current code builds Linear gain/bias layers, while the pretrained checkpoint was saved with Embedding layers (embed0/embed1).

This name mismatch is where the error above comes from.
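In other words, which module ConditionalBatchNorm2d ends up holding is decided by that config-level indirection. A toy illustration of the idea (MODULES here is a stand-in SimpleNamespace, not the real StudioGAN object):

```python
import types
import torch.nn as nn

# The config picks the factory: the latest config resolves g_linear to a
# plain Linear, while the old checkpoints imply an Embedding was used.
MODULES = types.SimpleNamespace(g_linear=nn.Linear)

gain = MODULES.g_linear(in_features=10, out_features=256, bias=False)
print(type(gain).__name__)                            # Linear
print([name for name, _ in gain.named_parameters()])  # ['weight']
```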

ConditionalBatchNorm2d will need to be modified if you want to load the author's pretrained generator, or you can choose to retrain. This is true for all the conditional GANs.
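If retraining is not an option, one possible workaround is to rename the old keys before loading. Below is a sketch under two assumptions I have not verified against the source: (1) the checkpoint stores the generator weights under a "state_dict" entry, and (2) the new Linear layers are fed one-hot label vectors, in which case Linear.weight of shape (out_features, in_features) is just the transpose of Embedding.weight of shape (num_embeddings, embedding_dim). Please check that outputs match before trusting any results:

```python
import torch

def remap_cbn_keys(old_state_dict):
    """Rename Embedding-based ConditionalBatchNorm2d keys (embed0/embed1)
    to the Linear-based names (gain/bias), transposing each table so that
    Linear(one_hot(y)) reproduces Embedding(y)."""
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.endswith(".embed0.weight"):
            new_state_dict[key.replace(".embed0.weight", ".gain.weight")] = value.t()
        elif key.endswith(".embed1.weight"):
            new_state_dict[key.replace(".embed1.weight", ".bias.weight")] = value.t()
        else:
            new_state_dict[key] = value
    return new_state_dict

# Usage sketch (the checkpoint path and "state_dict" key are assumptions):
# ckpt = torch.load("path/to/G.pth", map_location="cpu")
# generator.load_state_dict(remap_cbn_keys(ckpt["state_dict"]))
```

If assumption (2) turns out to be wrong, patching ConditionalBatchNorm2d back to the Embedding version (as in the sketch further up) or retraining is the safer route.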

Of course, I hope the author can look into this problem.

Best!

Leean
