PytorchStyleFormer

Pretrained model has problems

NTUYWANG103 opened this issue 2 years ago · 4 comments

def resume_eval(self, trained_generator):  # what do we need to load at test time...
    state_dict = torch.load(trained_generator)
    self.model.load_state_dict(state_dict['a'])   # error here
    self.decoder.load_state_dict(state_dict['b'])
    return 0

Exception has occurred: RuntimeError

  File "/home/wangyuxi/codes/PytorchStyleFormer/model.py", line 128, in resume_eval
    self.model.load_state_dict(state_dict['a'], strict=False)
  File "/home/wangyuxi/codes/PytorchStyleFormer/test.py", line 116, in
    initial_step = myNet.resume_eval(options.trained_network)
RuntimeError: Error(s) in loading state_dict for DataParallel:
    size mismatch for module.coeffs.local_features.0.conv.weight: copying a param with shape torch.Size([128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
    size mismatch for module.coeffs.local_features.0.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for module.coeffs.local_features.1.conv.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
    size mismatch for module.coeffs.conv_out.conv.weight: copying a param with shape torch.Size([17408, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([34816, 256, 1, 1]).
    size mismatch for module.coeffs.conv_out.conv.bias: copying a param with shape torch.Size([17408]) from checkpoint, the shape in current model is torch.Size([34816]).
    size mismatch for module.att.convsr.conv.weight: copying a param with shape torch.Size([1088, 1088, 3, 3]) from checkpoint, the shape in current model is torch.Size([2176, 2176, 3, 3]).
    size mismatch for module.att.convsr.conv.bias: copying a param with shape torch.Size([1088]) from checkpoint, the shape in current model is torch.Size([2176]).
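Note that strict=False (visible in the traceback) does not help here: it only ignores missing or unexpected keys, and a shape mismatch still raises. A quick way to see exactly which layers disagree is to diff the checkpoint tensor shapes against the freshly built model before calling load_state_dict. This is a minimal sketch, not code from the repo; it only assumes the checkpoint stores the generator weights under the 'a' key, as resume_eval above does:

import torch

def diff_checkpoint(model, ckpt_path, key='a'):
    # Hypothetical helper: print every parameter whose shape differs between
    # the checkpoint and the freshly constructed (DataParallel-wrapped) model.
    ckpt = torch.load(ckpt_path, map_location='cpu')[key]
    model_sd = model.state_dict()
    for name, tensor in ckpt.items():
        if name not in model_sd:
            print(f'only in checkpoint: {name}')
        elif model_sd[name].shape != tensor.shape:
            print(f'{name}: checkpoint {tuple(tensor.shape)} vs model {tuple(model_sd[name].shape)}')

On the released checkpoint this prints the same layers listed in the error above.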

NTUYWANG103 · Jan 17 '23

I've met the same problem. It seems there is something wrong with the pretrained checkpoints or the code, but I still cannot locate it. Please reply if there is any solution.

magus-jizx · Jan 27 '23

I met the same problem.

adaxidedakaonang · Apr 18 '23

I have checked the network parameters and changed the values of two input arguments, which seems to fix this error (though I am not sure why). Note that every mismatched dimension in the error above is exactly twice the checkpoint's size, which suggests these defaults were doubled relative to the settings the released checkpoint was trained with. In test.py, I changed luma_bins (line 86) from 8 to 4 and n_input_size (line 90) from 64 to 32, as

parser.add_argument('--luma_bins', type=int, default=4)
parser.add_argument('--n_input_size', type=int, default=32)

After that, the pre-trained checkpoints can be loaded correctly.
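If you want to confirm which settings a checkpoint matches before touching test.py, you can also print one of the stored tensor shapes directly. A small sketch, where the checkpoint path is a placeholder and the 'a' key and layer name are taken from the error above:

import torch

# Hypothetical sanity check: the released checkpoint stores 128 output channels
# for this layer, which is what the model expects when built with luma_bins=4
# and n_input_size=32.
ckpt = torch.load('path/to/pretrained_checkpoint.pth', map_location='cpu')
print(ckpt['a']['module.coeffs.local_features.0.conv.weight'].shape)
# expected: torch.Size([128, 256, 3, 3])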

LeoLiu0918 · May 26 '23

> I have checked the network parameters and changed the values of two input arguments ... In test.py, I changed luma_bins (line 86) from 8 to 4 and n_input_size (line 90) from 64 to 32. After that, the pre-trained checkpoints can be loaded correctly.

Great, you are right!

weiyang001 · Sep 30 '23