stylegan2-ada-pytorch

Problems in converting to .pt

Open • ucalyptus2 opened this issue 4 years ago • 1 comment

  • I trained a custom StyleGAN2-ADA model (custom dataset, 256px) with the official PyTorch implementation, so the checkpoint is a .pkl file.
  • I converted it with export_weights.py, and the conversion finished without errors.
  • Running generate.py on the converted checkpoint fails with the error below.

The Error

Traceback (most recent call last):
  File "generate.py", line 76, in <module>
    g_ema.load_state_dict(checkpoint["g_ema"])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "style.3.weight", "style.3.bias", "style.4.weight", "style.4.bias", "style.5.weight", "style.5.bias", "style.6.weight", "style.6.bias", "style.7.weight", "style.7.bias", "style.8.weight", "style.8.bias". 
	size mismatch for convs.6.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
	size mismatch for convs.6.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.7.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 512, 512, 3, 3]).
	size mismatch for convs.7.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for convs.7.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.7.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.8.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
	size mismatch for convs.8.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for convs.8.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for convs.8.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.9.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
	size mismatch for convs.9.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for convs.9.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.9.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.10.conv.weight: copying a param with shape torch.Size([1, 64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
	size mismatch for convs.10.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for convs.10.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convs.10.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convs.11.conv.weight: copying a param with shape torch.Size([1, 64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
	size mismatch for convs.11.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for convs.11.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convs.11.activate.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
	size mismatch for to_rgbs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for to_rgbs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
	size mismatch for to_rgbs.4.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for to_rgbs.4.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 3, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
	size mismatch for to_rgbs.5.conv.modulation.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for to_rgbs.5.conv.modulation.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
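The shapes in the traceback above can be read mechanically. Below is a minimal Python sketch, assuming generate.py builds a Generator with the key layout of rosinality's stylegan2-pytorch. In that layout, style.0 is a parameter-free PixelNorm, so the mapping layers are style.1 to style.N. Each resolution from 8px upward has two convs, and each conv's conv.weight is shaped (1, out, in, k, k). The helper name `infer_generator_config` is hypothetical. The `checkpoint_shapes` dict is a hand-built excerpt: the conv entries are copied from the checkpoint side of the error, and the two style entries are assumed because only style.3 onward is reported missing.

```python
import re

# Hypothetical helper. Assumes rosinality-style Generator keys:
#   style.0 is a parameter-free PixelNorm, so mapping layers are style.1 .. style.N
#   convs.k.conv.weight has shape (1, out_channels, in_channels, kh, kw), and
#   convs come in pairs per resolution starting at 8px (k = 0, 1 -> 8px).
def infer_generator_config(shapes):
    """Return (n_mlp, {resolution: conv output channels}) from a {key: shape} map."""
    style_layers = set()
    widths = {}
    for key, shape in shapes.items():
        m = re.fullmatch(r"style\.(\d+)\.weight", key)
        if m:
            style_layers.add(int(m.group(1)))
            continue
        m = re.fullmatch(r"convs\.(\d+)\.conv\.weight", key)
        if m:
            resolution = 2 ** (int(m.group(1)) // 2 + 3)
            widths[resolution] = shape[1]  # output channels of this conv
    return len(style_layers), dict(sorted(widths.items()))


# Excerpt of the checkpoint side of the traceback; the style entries are assumed.
checkpoint_shapes = {
    "style.1.weight": (512, 512),
    "style.2.weight": (512, 512),
    "convs.6.conv.weight": (1, 256, 512, 3, 3),
    "convs.7.conv.weight": (1, 256, 256, 3, 3),
    "convs.8.conv.weight": (1, 128, 256, 3, 3),
    "convs.9.conv.weight": (1, 128, 128, 3, 3),
    "convs.10.conv.weight": (1, 64, 128, 3, 3),
    "convs.11.conv.weight": (1, 64, 64, 3, 3),
}

n_mlp, widths = infer_generator_config(checkpoint_shapes)
print(n_mlp, widths)  # → 2 {64: 256, 128: 128, 256: 64}
```

Under these assumptions, the checkpoint has a 2-layer mapping network. Its feature maps are also half as wide as the model generate.py built (256/128/64 versus 512/256/128 at 64/128/256px). That is consistent with both the missing keys and the size mismatches. With a real converted file, the same map could be built as `{k: tuple(v.shape) for k, v in checkpoint["g_ema"].items()}`.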

Reproduce the error

cc: @dvschultz

ucalyptus2 • Nov 07 '21 02:11

Yes, that script assumes you’re converting a 1024px model. There may be a fix for 256px models here: https://github.com/dvschultz/stylegan2-ada-pytorch/issues/6

dvschultz • Nov 07 '21 03:11
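One way to check which Generator settings a 256px checkpoint needs is to compare the per-resolution widths from the traceback against the channel table that rosinality's stylegan2-pytorch Generator derives from `channel_multiplier`. Its default of 2 matches the "current model" side of the error. Below is a minimal sketch, assuming that table is the one generate.py uses. `rosinality_channels` and `match_channel_multiplier` are hypothetical helper names.

```python
# Assumed to mirror the channel table in rosinality's stylegan2-pytorch model.py.
def rosinality_channels(channel_multiplier):
    return {
        4: 512, 8: 512, 16: 512, 32: 512,
        64: 256 * channel_multiplier,
        128: 128 * channel_multiplier,
        256: 64 * channel_multiplier,
        512: 32 * channel_multiplier,
        1024: 16 * channel_multiplier,
    }


def match_channel_multiplier(observed, candidates=(1, 2)):
    """Return the first multiplier whose table matches every observed width, else None."""
    for cm in candidates:
        table = rosinality_channels(cm)
        if all(table.get(res) == width for res, width in observed.items()):
            return cm
    return None


observed = {64: 256, 128: 128, 256: 64}  # checkpoint widths read off the traceback
print(match_channel_multiplier(observed))  # → 1
```

If the match holds, constructing the Generator with `n_mlp=2` and `channel_multiplier=1` instead of 8 and 2 should let `load_state_dict` succeed. This is an assumption to verify against the linked issue. These values also look like what stylegan2-ada-pytorch's `auto` training config is understood to use below 512px (2 mapping layers, `fmaps=0.5`), but check the options actually used for training.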