Sampling at 64x64 - Missing key(s) in state_dict / size mismatch - segfault

Open acardara opened this issue 3 years ago • 18 comments

I want to sample images from the pretrained 64x64_diffusion model but am hitting a segfault with the suggested run configuration. I've downloaded the 64x64 checkpoints to a models folder and am running with the following flags.

!SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"

!MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"

!python image_sample.py $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS

At runtime, I get a slew of warnings about missing and unused keys before the code crashes via segfault:

Missing key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.4.0.skip_connection.weight", ..., "output_blocks.8.1.conv.bias".

Unexpected key(s) in state_dict: "label_emb.weight", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", ..., "output_blocks.11.2.out_layers.3.bias".

size mismatch for time_embed.0.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([512, 128]). ... size mismatch for out.2.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([3]).

acardara avatar Jul 27 '21 01:07 acardara

I matched the model architecture as suggested in #7, which removed the mismatch warnings, but the missing and unexpected key warnings are still there. I am still getting a segfault.

acardara avatar Jul 27 '21 02:07 acardara
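The three kinds of message in the report above (missing keys, unexpected keys, size mismatches) can be reproduced without running the sampler by comparing parameter names and shapes directly. The following is a minimal sketch, not code from guided-diffusion: `diff_state_dicts` is a hypothetical helper, and the toy shapes marked as made up are placeholders rather than values from the real checkpoint.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> Dict[str, List]:
    """Compare parameter names and shapes of a model against a checkpoint."""
    model_keys = set(model_shapes)
    ckpt_keys = set(ckpt_shapes)
    return {
        # Parameters the model expects but the checkpoint lacks.
        "missing": sorted(model_keys - ckpt_keys),
        # Parameters the checkpoint holds but the model has no slot for.
        "unexpected": sorted(ckpt_keys - model_keys),
        # Shared names whose shapes differ: (name, checkpoint shape, model shape).
        "size_mismatch": sorted(
            (k, tuple(ckpt_shapes[k]), tuple(model_shapes[k]))
            for k in model_keys & ckpt_keys
            if tuple(model_shapes[k]) != tuple(ckpt_shapes[k])
        ),
    }


# With PyTorch the inputs could be collected like this (not executed here):
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   ckpt = torch.load(path, map_location="cpu")
#   ckpt_shapes = {k: tuple(v.shape) for k, v in ckpt.items()}

# Toy values; names and the first two shape pairs come from the error text above.
model_shapes = {
    "time_embed.0.weight": (512, 128),
    "out.2.bias": (3,),
    "input_blocks.3.0.op.weight": (128, 128, 3, 3),  # shape made up for illustration
}
ckpt_shapes = {
    "time_embed.0.weight": (768, 192),
    "out.2.bias": (6,),
    "label_emb.weight": (1000, 768),  # shape made up for illustration
}

report = diff_state_dicts(model_shapes, ckpt_shapes)
for kind, entries in report.items():
    print(kind, entries)
```

In the log above, the model side has a 128-wide time embedding input and 3 output channels while the checkpoint has 192 and 6. That pattern is what one would expect if `--num_channels 192` and `--learn_sigma True` never reached the script, so it is worth checking how the flags are passed before suspecting the checkpoint.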

Hi, I get a runtime error with the same message:

RuntimeError: Error(s) in loading state_dict for SuperResModel: Missing key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.9.0.op.weight", "input_blocks.9.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "output_blocks.2.2.conv.weight", "output_blocks.2.2.conv.bias", "output_blocks.5.2.conv.weight", "output_blocks.5.2.conv.bias", "output_blocks.8.1.conv.weight", "output_blocks.8.1.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.14.1.conv.weight", "output_blocks.14.1.conv.bias".

Did you manage to solve it?

lmvgjp avatar Dec 15 '21 08:12 lmvgjp

No, I didn't find a solution.

acardara avatar Dec 15 '21 17:12 acardara

ok, thanks for answering!

lmvgjp avatar Dec 16 '21 06:12 lmvgjp

Hi, I encountered the same problem here. I guess the published model has a slightly different architecture from the one defined in the code. Can you please check whether they match?

inbarhub avatar Jan 17 '22 14:01 inbarhub

I solved it by using 'restrict=False' flag when loading the model but the results I get are really poor. I guess this is because the model was not loaded well.

inbarhub avatar Jan 17 '22 14:01 inbarhub

Any news on this? I'm hitting the same issue.

EDIT: I was defining the environment variables the wrong way (new to Jupyter 😅 )

DiogoNeves avatar May 27 '22 16:05 DiogoNeves

@acardara I'm still having other issues, but I think this might help you. From your message, you seem to be running this in a Jupyter notebook. You're currently defining the environment variables with !, which doesn't persist to other commands.

Try using %env like:

%env SAMPLE_FLAGS=...
!python image_sample.py ... ${SAMPLE_FLAGS}

I'm new to Jupyter and was running into a similar issue. My understanding is that !SAMPLE_FLAGS would only work if you run the Python script on the same line, similar to setting a variable inline in bash.
I haven't tried it, but !SAMPLE_FLAGS="..." python ... $SAMPLE_FLAGS should work if I'm right.

DiogoNeves avatar May 27 '22 16:05 DiogoNeves
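One way to sidestep shell-variable scoping in a notebook altogether is to keep the flags as Python strings and build the command in Python. This is a sketch under assumptions: the flag strings, script name, and checkpoint path are copied from the first post, and `build_command` is a hypothetical helper, not part of the repository.

```python
import shlex
import sys
from typing import List

MODEL_FLAGS = (
    "--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 "
    "--dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine "
    "--num_channels 192 --num_head_channels 64 --num_res_blocks 3 "
    "--resblock_updown True --use_new_attention_order True --use_fp16 True "
    "--use_scale_shift_norm True"
)
SAMPLE_FLAGS = "--batch_size 4 --num_samples 100 --timestep_respacing 250"


def build_command(model_path: str) -> List[str]:
    """Return an argv list; shlex.split tokenizes the flag strings like a POSIX shell."""
    return [
        sys.executable,
        "image_sample.py",
        *shlex.split(MODEL_FLAGS),
        "--model_path",
        model_path,
        *shlex.split(SAMPLE_FLAGS),
    ]


cmd = build_command("models/64x64_diffusion.pt")
# Print the command so it can be checked before anything is launched.
print(" ".join(shlex.quote(part) for part in cmd))

# To launch it (not executed here):
#   import subprocess
#   subprocess.run(cmd, check=True)
```

Printing the command first makes it easy to confirm that every flag is actually present. IPython also expands Python variables in `!` lines, so with the two strings defined in a cell, `!python image_sample.py $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS` should pick them up from the Python namespace.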

inbarhub wrote: "I solved it by using 'restrict=False' flag when loading the model but the results I get are really poor. I guess this is because the model was not loaded well."

Hi @inbarhub, I tried using restrict=False here:

model.load_state_dict( dist_util.load_state_dict(args.model_path, restrict=False) )

but it did not work.

shahdghorsi avatar Jul 31 '22 18:07 shahdghorsi

I solved the same problem!

model.load_state_dict( dist_util.load_state_dict(args.model_path, map_location="cpu"), strict=False )

wangqiang9 avatar Sep 11 '22 05:09 wangqiang9
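Note that the working call spells the keyword strict and passes it to the outer model.load_state_dict, whereas the earlier attempt passed restrict to the inner dist_util.load_state_dict. Also, strict=False suppresses key-mismatch errors rather than fixing them: parameters missing from the checkpoint keep their initial values, which would be consistent with the poor samples reported earlier in the thread. PyTorch's `Module.load_state_dict` returns an object with `missing_keys` and `unexpected_keys`, so a non-strict load can still be checked. A minimal sketch, where `require_clean_load` is a hypothetical helper and a namedtuple stands in for the PyTorch return value:

```python
from collections import namedtuple
from typing import Sequence


def require_clean_load(result, allow_missing: Sequence[str] = ()) -> None:
    """Raise if a non-strict load skipped parameters.

    `result` is any object with `missing_keys` and `unexpected_keys`
    attributes, such as the return value of torch's Module.load_state_dict.
    """
    allowed = set(allow_missing)
    missing = [k for k in result.missing_keys if k not in allowed]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"{len(missing)} parameter(s) left at their initial values, "
            f"{len(unexpected)} checkpoint key(s) ignored; "
            f"first missing: {missing[:3]}, first unexpected: {unexpected[:3]}"
        )


# Stand-in for the PyTorch return value so the sketch runs without torch.
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])

clean = LoadResult(missing_keys=[], unexpected_keys=[])
require_clean_load(clean)  # nothing skipped, so no exception

partial = LoadResult(
    missing_keys=["input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias"],
    unexpected_keys=["label_emb.weight"],
)
try:
    require_clean_load(partial)
except RuntimeError as err:
    print("load was incomplete:", err)

# With the call from this thread (not executed here):
#   result = model.load_state_dict(
#       dist_util.load_state_dict(args.model_path, map_location="cpu"),
#       strict=False,
#   )
#   require_clean_load(result)
```

If the check fails, the key names usually point at which model flag differs from the checkpoint, which is a more reliable fix than loading non-strictly.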