PyTorch-StudioGAN
RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict:
When I load the pre-trained ACGAN-Mod generator, I get this error:
```
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "blocks.0.0.bn1.gain.weight", "blocks.0.0.bn1.bias.weight", "blocks.0.0.bn2.gain.weight", "blocks.0.0.bn2.bias.weight", "blocks.1.0.bn1.gain.weight", "blocks.1.0.bn1.bias.weight", "blocks.1.0.bn2.gain.weight", "blocks.1.0.bn2.bias.weight", "blocks.2.0.bn1.gain.weight", "blocks.2.0.bn1.bias.weight", "blocks.2.0.bn2.gain.weight", "blocks.2.0.bn2.bias.weight".
    Unexpected key(s) in state_dict: "blocks.0.0.bn1.embed0.weight", "blocks.0.0.bn1.embed1.weight", "blocks.0.0.bn2.embed0.weight", "blocks.0.0.bn2.embed1.weight", "blocks.1.0.bn1.embed0.weight", "blocks.1.0.bn1.embed1.weight", "blocks.1.0.bn2.embed0.weight", "blocks.1.0.bn2.embed1.weight", "blocks.2.0.bn1.embed0.weight", "blocks.2.0.bn1.embed1.weight", "blocks.2.0.bn2.embed0.weight", "blocks.2.0.bn2.embed1.weight".
```
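For anyone who wants to see the mechanism in isolation, here is a tiny self-contained demo (toy modules, not StudioGAN code) of how `load_state_dict` produces exactly this pair of missing/unexpected keys when submodule names differ between save time and load time:

```python
import torch.nn as nn

# Toy demo: the checkpoint was saved from a module whose submodule is
# named "embed0", but we are loading it into one whose submodule is "gain".
saved = nn.ModuleDict({"embed0": nn.Embedding(10, 256)}).state_dict()
target = nn.ModuleDict({"gain": nn.Linear(10, 256, bias=False)})

target.load_state_dict(saved)
# RuntimeError: Missing key(s) in state_dict: "gain.weight".
#               Unexpected key(s) in state_dict: "embed0.weight".
```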
So I went to find out why. The generator built by the current code has this structure:
```
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (gain): Linear(in_features=10, out_features=256, bias=False)
          (bias): Linear(in_features=10, out_features=256, bias=False)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
```
And the author's training log shows this structure:
```
Generator(
  (linear0): Linear(in_features=128, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GenBlock(
        (bn1): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (bn2): ConditionalBatchNorm2d(
          (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (embed0): Embedding(10, 256)
          (embed1): Embedding(10, 256)
        )
        (activation): ReLU(inplace=True)
        (conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
```
Then I looked at ConditionalBatchNorm2d in the latest code (in ops.py) and found:
```python
self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
```
but in config.py, `g_linear = ops.linear`, so the current code builds gain and bias as Linear layers, while the pre-trained checkpoint stores them as Embedding layers (embed0/embed1). This is where the above error comes from.
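To make the mismatch concrete, here is a stripped-down sketch of the two ConditionalBatchNorm2d variants, reconstructed from the two module dumps above (the forward passes are my guess at the usual cBN formulation, not the exact StudioGAN code):

```python
import torch.nn as nn

class CBN2dEmbedding(nn.Module):
    """Variant the pre-trained weights were saved with: per-class
    gain/bias rows looked up from Embedding tables."""
    def __init__(self, num_classes, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, affine=False)
        self.embed0 = nn.Embedding(num_classes, num_features)  # -> keys "...embed0.weight"
        self.embed1 = nn.Embedding(num_classes, num_features)  # -> keys "...embed1.weight"

    def forward(self, x, y):
        # y: LongTensor of class indices
        gain = (1 + self.embed0(y)).view(-1, x.size(1), 1, 1)
        bias = self.embed1(y).view(-1, x.size(1), 1, 1)
        return self.bn(x) * gain + bias

class CBN2dLinear(nn.Module):
    """Variant built by the current code, where MODULES.g_linear = ops.linear:
    gain/bias come from Linear layers applied to a conditioning vector."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(out_features, eps=1e-4, affine=False)
        self.gain = nn.Linear(in_features, out_features, bias=False)  # -> keys "...gain.weight"
        self.bias = nn.Linear(in_features, out_features, bias=False)  # -> keys "...bias.weight"

    def forward(self, x, affine):
        # affine: FloatTensor conditioning vector, e.g. a one-hot label
        gain = (1 + self.gain(affine)).view(-1, x.size(1), 1, 1)
        bias = self.bias(affine).view(-1, x.size(1), 1, 1)
        return self.bn(x) * gain + bias
```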
ConditionalBatchNorm2d has to be modified if you want to load the author's pre-trained generator; otherwise you have to retrain. This applies to all the conditional GANs.
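If retraining is not an option, one possible workaround is to remap the checkpoint keys before loading. This is only a sketch: it assumes embed0 held the gains and embed1 the biases, and it is numerically equivalent only if the new gain/bias Linear layers are fed one-hot class labels (`Linear(in_features=10, ...)` for 10 classes suggests that, but please verify against ops.py):

```python
def remap_cbn_keys(state_dict):
    """Rename Embedding-based CBN keys to the Linear-based names.
    Embedding(10, 256).weight has shape (10, 256) while
    Linear(10, 256).weight has shape (256, 10), hence the transpose."""
    new_sd = {}
    for key, value in state_dict.items():
        if key.endswith("embed0.weight"):    # assumption: embed0 = per-class gains
            new_sd[key.replace("embed0.weight", "gain.weight")] = value.t()
        elif key.endswith("embed1.weight"):  # assumption: embed1 = per-class biases
            new_sd[key.replace("embed1.weight", "bias.weight")] = value.t()
        else:
            new_sd[key] = value
    return new_sd

# Continuing the sketch classes above: the remapped dict now loads cleanly.
old_sd = CBN2dEmbedding(num_classes=10, num_features=256).state_dict()
CBN2dLinear(in_features=10, out_features=256).load_state_dict(remap_cbn_keys(old_sd))

# For a real checkpoint (hypothetical path and key; the actual layout may differ):
# ckpt = torch.load("G.pth", map_location="cpu")
# generator.load_state_dict(remap_cbn_keys(ckpt["state_dict"]))
```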
I hope the author can look into this problem.
Best!
Leean