PytorchStyleFormer
Pretrained model has problems
def resume_eval(self, trained_generator):  # what exactly is needed at test time...
    state_dict = torch.load(trained_generator)
    self.model.load_state_dict(state_dict['a'])  # error here
    self.decoder.load_state_dict(state_dict['b'])
    return 0
Exception has occurred: RuntimeError
Error(s) in loading state_dict for DataParallel:
size mismatch for module.coeffs.local_features.0.conv.weight: copying a param with shape torch.Size([128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for module.coeffs.local_features.0.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for module.coeffs.local_features.1.conv.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for module.coeffs.conv_out.conv.weight: copying a param with shape torch.Size([17408, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([34816, 256, 1, 1]).
size mismatch for module.coeffs.conv_out.conv.bias: copying a param with shape torch.Size([17408]) from checkpoint, the shape in current model is torch.Size([34816]).
size mismatch for module.att.convsr.conv.weight: copying a param with shape torch.Size([1088, 1088, 3, 3]) from checkpoint, the shape in current model is torch.Size([2176, 2176, 3, 3]).
size mismatch for module.att.convsr.conv.bias: copying a param with shape torch.Size([1088]) from checkpoint, the shape in current model is torch.Size([2176]).
File "/home/wangyuxi/codes/PytorchStyleFormer/model.py", line 128, in resume_eval
self.model.load_state_dict(state_dict['a'], strict=False)
File "/home/wangyuxi/codes/PytorchStyleFormer/test.py", line 116, in
I've met the same problem. It seems there is something wrong with the pretrained checkpoints or the code, but I still cannot locate it. Thanks in advance for any solution.
I met the same problem.
I have checked the network parameters and changed the values of two input arguments; this seems to fix the error (though I don't know why). In test.py, I changed luma_bins (line 86) from 8 to 4 and n_input_size (line 90) from 64 to 32:
parser.add_argument('--luma_bins', type=int, default=4)
parser.add_argument('--n_input_size', type=int, default=32)
After that, the pre-trained checkpoints can be loaded correctly.
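This is consistent with the traceback above: every mismatched dimension in the checkpoint is exactly half of what the current defaults build (128 vs. 256, 17408 vs. 34816, 1088 vs. 2176), which matches halving both luma_bins (8 to 4) and n_input_size (64 to 32). A quick hedged check, where build_model and the checkpoint filename are stand-ins for however test.py actually constructs the network and names the file:

# Hypothetical verification: with the changed defaults the strict load
# should succeed instead of raising the size-mismatch RuntimeError.
args = parser.parse_args()            # now defaults to luma_bins=4, n_input_size=32
model = build_model(args)             # stand-in for the constructor used in test.py
state_dict = torch.load('trained_generator.pth', map_location='cpu')
model.load_state_dict(state_dict['a'])  # raises only if shapes still mismatch

Alternatively, diff_shapes from the sketch above should now print nothing.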
Great! You are right!