Sampling at 64x64 - Missing key(s) in state_dict / size mismatch - segfault
I want to sample images from the pretrained 64x64_diffusion model but am hitting a segfault with the suggested run configuration. I've downloaded the 64x64 checkpoints to a models folder and am running with the following flags.
!SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"
!MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
!python image_sample.py $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS
At runtime, I get a slew of warnings about missing and unexpected keys before the code crashes with a segfault:
Missing key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.4.0.skip_connection.weight", ..., "output_blocks.8.1.conv.bias".
Unexpected key(s) in state_dict: "label_emb.weight", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", ..., "output_blocks.11.2.out_layers.3.bias".
size mismatch for time_embed.0.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([512, 128]). ... size mismatch for out.2.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([3]).
I matched the model architecture as suggested in #7, which removed the size-mismatch warnings, but the missing and unexpected key warnings remain, and I still get a segfault.
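For anyone else debugging this: a minimal sketch of what matching the architecture looks like, assuming the script_util helpers from this repo and that the released .pt file is a plain state_dict saved with torch.save. It builds the model with the documented 64x64 flags from MODEL_FLAGS above, then diffs its keys against the checkpoint, so any remaining flag mismatch shows up directly:

import torch
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
)

# Start from the repo defaults, then override with the 64x64 settings
# from MODEL_FLAGS above.
options = model_and_diffusion_defaults()
options.update(dict(
    attention_resolutions="32,16,8",
    class_cond=True,
    diffusion_steps=1000,
    dropout=0.1,
    image_size=64,
    learn_sigma=True,
    noise_schedule="cosine",
    num_channels=192,
    num_head_channels=64,
    num_res_blocks=3,
    resblock_updown=True,
    use_new_attention_order=True,
    use_fp16=True,
    use_scale_shift_norm=True,
))
model, diffusion = create_model_and_diffusion(**options)

# Diff the key sets: any asymmetry means the flags above do not match
# the architecture the checkpoint was trained with.
ckpt = torch.load("models/64x64_diffusion.pt", map_location="cpu")
model_keys = set(model.state_dict().keys())
ckpt_keys = set(ckpt.keys())
print("missing from checkpoint:", sorted(model_keys - ckpt_keys)[:10])
print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys)[:10])

If both printed lists are empty, the architecture matches and any remaining crash is unrelated to loading.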
Hi, I get a runtime error with the same message:
RuntimeError: Error(s) in loading state_dict for SuperResModel: Missing key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.9.0.op.weight", "input_blocks.9.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "output_blocks.2.2.conv.weight", "output_blocks.2.2.conv.bias", "output_blocks.5.2.conv.weight", "output_blocks.5.2.conv.bias", "output_blocks.8.1.conv.weight", "output_blocks.8.1.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.14.1.conv.weight", "output_blocks.14.1.conv.bias".
Did you manage to solve it?
No, I didn't find a solution.
ok, thanks for answering!
Hi, I encountered the same problem. I suspect the published model has a slightly different architecture from the one defined in the code. Can you please check whether they match?
I solved it by using the 'restrict=False' flag when loading the model, but the results I get are really poor. I guess this is because the model was not loaded correctly.
Any news on this? I'm hitting the same issue.
EDIT: I was defining the environment variables the wrong way (new to Jupyter 😅 )
@acardara I'm still having other issues, but I think this might help you. From your message you seem to be running this within a Jupyter notebook. You're currently defining the environment variables with !, which doesn't persist them to later commands. Try using %env instead, like:
%env SAMPLE_FLAGS=...
!python image_sample.py ... ${SAMPLE_FLAGS}
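Concretely, with the flags from the original post, the cell could look like this (quotes around the values are typically not needed, since %env takes the rest of the line as the value):

%env MODEL_FLAGS=--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True
%env SAMPLE_FLAGS=--batch_size 4 --num_samples 100 --timestep_respacing 250
!python image_sample.py $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS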
I'm new to Jupyter and was running into a similar issue. My understanding is that !SAMPLE_FLAGS=... only works if you run the Python script on the same line, similar to setting a variable inline in bash. I haven't tried it, but
!SAMPLE_FLAGS="..." python ... $SAMPLE_FLAGS
should work if I'm right.
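For what it's worth, a quick check in a plain bash shell suggests the inline form would not expand as hoped, because $SAMPLE_FLAGS on the same line is expanded before the inline assignment takes effect:

FOO=bar echo $FOO          # prints an empty line: $FOO expands before FOO=bar applies
FOO=bar sh -c 'echo $FOO'  # prints "bar": the child shell sees FOO in its environment

So the %env approach suggested above seems safer.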
Hi @inbarhub, I tried using restrict=False here:
model.load_state_dict(dist_util.load_state_dict(args.model_path, restrict=False))
but it did not work.
I solved the same problem with:
model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False)
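For context, this matches the load call in image_sample.py: strict=False is an argument of torch's model.load_state_dict, not of dist_util.load_state_dict, which is likely why the earlier attempt did not work. A minimal sketch, assuming model is built as in image_sample.py, that also reports what strict=False skipped:

from guided_diffusion import dist_util

state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")
# torch returns the incompatible keys when strict=False, so the
# mismatch is visible instead of silently ignored.
result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)

Keep in mind that strict=False only hides a mismatch: any key in either printed list is a weight left at its random initialization, which would explain the poor samples reported above. If the lists are non-empty, the real fix is still to align the model flags with the checkpoint, as in #7.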