GANsNRoses

About using the pretrained model to fine-tune

Open • 523997931 opened this issue 3 years ago • 1 comment

Hello, I am trying to fine-tune from the pretrained model you provided, but when I set the ckpt in the training script, I get this error:

    2021/07/08 20:26:17 File "train.py", line 377, in
    Loading model from: /opt/conda/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth
    2021/07/08 20:26:17 G_A2B.load_state_dict(ckpt['G_A2B'])
    2021/07/08 20:26:17 Unexpected key(s) in state_dict: "encoder.stem.5.conv1.0.weight", "encoder.stem.5.conv1.1.bias", "encoder.stem.5.conv2.0.weight", "encoder.stem.5.conv2.1.bias", "encoder.stem.4.skip.0.kernel", "encoder.stem.4.skip.1.weight", "encoder.stem.4.conv2.2.bias", "encoder.stem.4.conv2.0.kernel", "encoder.stem.4.conv2.1.weight", "encoder.style.3.weight", "encoder.style.3.bias", "convs.6.conv.weight", "convs.6.conv.blur.kernel", "convs.6.conv.modulation.weight", "convs.6.conv.modulation.bias", "convs.6.activate.bias", "convs.7.conv.weight", "convs.7.conv.modulation.weight", "convs.7.conv.modulation.bias", "convs.7.activate.bias", "to_rgbs.3.bias", "to_rgbs.3.upsample.kernel", "to_rgbs.3.conv.weight", "to_rgbs.3.conv.modulation.weight", "to_rgbs.3.conv.modulation.bias".
    2021/07/08 20:26:17 size mismatch for convs.0.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
    2021/07/08 20:26:17 RuntimeError: Error(s) in loading state_dict for Generator:
    2021/07/08 20:26:17 size mismatch for encoder.style.4.weight: copying a param with shape torch.Size([8, 512]) from checkpoint, the shape in current model is torch.Size([512, 8192]).
    2021/07/08 20:26:17 size mismatch for convs.0.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.3.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.2.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for convs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.5.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for convs.3.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.weight: copying a param with shape torch.Size([1, 3, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 64, 1, 1]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 category=DeprecationWarning,
    2021/07/08 20:26:17 Missing key(s) in state_dict: "encoder.stem.4.conv2.0.weight", "encoder.stem.4.conv2.1.bias", "encoder.style.2.0.kernel", "encoder.style.2.1.weight", "encoder.style.2.2.bias", "encoder.style.5.weight", "encoder.style.5.bias".
    2021/07/08 20:26:17 size mismatch for convs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 128, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.1.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    2021/07/08 20:26:17 File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 Traceback (most recent call last):
    2021/07/08 20:26:17 size mismatch for encoder.style.4.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([512]).
    2021/07/08 20:26:17 size mismatch for convs.4.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).

How can I fix this? Maybe with strict=False?
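On the strict=False idea: in PyTorch, `load_state_dict(..., strict=False)` only tolerates missing and unexpected keys. Size mismatches are still collected as errors and raise a `RuntimeError`, so it would not get past this traceback. A quicker way to see the whole picture is to compare the checkpoint against a freshly built model before loading. The sketch below assumes both arguments map parameter names to either tensors (anything with a `.shape`) or plain shape tuples; the helper names `_shape` and `diff_state_dicts` are hypothetical and not part of GANsNRoses or PyTorch.

```python
from typing import Any, Dict, List, Mapping, Tuple


def _shape(value: Any) -> Tuple[int, ...]:
    # Accept tensors (anything with .shape, e.g. torch.Tensor) or plain shape tuples.
    return tuple(getattr(value, "shape", value))


def diff_state_dicts(
    ckpt_sd: Mapping[str, Any], model_sd: Mapping[str, Any]
) -> Dict[str, List[Any]]:
    """Report how a checkpoint state dict differs from a model state dict."""
    ckpt_keys, model_keys = set(ckpt_sd), set(model_sd)
    # Keys the model expects but the checkpoint lacks ("Missing key(s)").
    missing = sorted(model_keys - ckpt_keys)
    # Keys in the checkpoint that the model has no slot for ("Unexpected key(s)").
    unexpected = sorted(ckpt_keys - model_keys)
    # Shared keys whose shapes disagree ("size mismatch for ...").
    # strict=False does not skip these; load_state_dict still raises.
    mismatched = [
        (key, _shape(ckpt_sd[key]), _shape(model_sd[key]))
        for key in sorted(ckpt_keys & model_keys)
        if _shape(ckpt_sd[key]) != _shape(model_sd[key])
    ]
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}
```

For example, `diff_state_dicts(ckpt['G_A2B'], G_A2B.state_dict())` lists the same three kinds of problems the traceback reports, but all at once and without raising. If `mismatched` is non-empty, the model was built with different hyperparameters than the checkpoint, and the fix is to rebuild the model to match rather than to relax `strict`.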

523997931, Jul 08 '21 12:07

Try setting num_down=4
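Every error in the log above is architectural: the checkpoint has layers the current model lacks (for example `encoder.stem.5.*`, `convs.6.*`, `convs.7.*`, `to_rgbs.3.*`), and shared layers differ in width (512 vs 256 channels, and so on). So the generator has to be built with the same `num_down` the checkpoint was trained with. If you are unsure which value that is, one option is to build the model for a few candidate values and keep the one whose parameter names and shapes match the checkpoint exactly. This is a minimal sketch: `build_state_dict_shapes` is a hypothetical callable you would write around however `train.py` constructs `G_A2B` (not shown in this thread), and the default candidate range is an arbitrary assumption.

```python
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple


def _shape(value: Any) -> Tuple[int, ...]:
    # Accept tensors (anything with .shape) or plain shape tuples.
    return tuple(getattr(value, "shape", value))


def find_matching_num_down(
    ckpt_sd: Mapping[str, Any],
    build_state_dict_shapes: Callable[[int], Mapping[str, Any]],
    candidates: Iterable[int] = (2, 3, 4, 5),  # assumed search range
) -> Optional[int]:
    """Return the first candidate whose model keys and shapes equal the checkpoint's.

    build_state_dict_shapes is hypothetical: given num_down, it should build
    the generator with that setting and return its state_dict() (or a
    name -> shape mapping).
    """
    target = {k: _shape(v) for k, v in ckpt_sd.items()}
    for num_down in candidates:
        model = {k: _shape(v) for k, v in build_state_dict_shapes(num_down).items()}
        if model == target:
            return num_down
    return None  # no candidate reproduces the checkpoint's architecture
```

Once a matching value is found, pass it when launching fine-tuning (for instance as `--num_down 4`, assuming `train.py` exposes it as a command-line argument) so that `G_A2B.load_state_dict(ckpt['G_A2B'])` can load with the default `strict=True`.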

mchong6, Jul 09 '21 13:07