DreamPose
dict key error in demo
Hello. When I run the demo, I get the dictionary key error below. I am using PyTorch 2.0.1 and CUDA 11.7. Is there a specific version of PyTorch I should use? Any ideas how I can resolve this? Thank you.
Traceback (most recent call last):
File "test.py", line 87, in <module>
pipe.vae.load_state_dict(new_state_dict)
File "C:\Users\best4\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py", line 1672, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for AutoencoderKL:
Missing key(s) in state_dict: "encoder.mid_block.attentions.0.to_q.weight", "encoder.mid_block.attentions.0.to_q.bias", "encoder.mid_block.attentions.0.to_k.weight", "encoder.mid_block.attentions.0.to_k.bias", "encoder.mid_block.attentions.0.to_v.weight", "encoder.mid_block.attentions.0.to_v.bias", "encoder.mid_block.attentions.0.to_out.0.weight", "encoder.mid_block.attentions.0.to_out.0.bias", "decoder.mid_block.attentions.0.to_q.weight", "decoder.mid_block.attentions.0.to_q.bias", "decoder.mid_block.attentions.0.to_k.weight", "decoder.mid_block.attentions.0.to_k.bias", "decoder.mid_block.attentions.0.to_v.weight", "decoder.mid_block.attentions.0.to_v.bias", "decoder.mid_block.attentions.0.to_out.0.weight", "decoder.mid_block.attentions.0.to_out.0.bias".
Unexpected key(s) in state_dict: "encoder.mid_block.attentions.0.query.weight", "encoder.mid_block.attentions.0.query.bias", "encoder.mid_block.attentions.0.key.weight", "encoder.mid_block.attentions.0.key.bias", "encoder.mid_block.attentions.0.value.weight", "encoder.mid_block.attentions.0.value.bias", "encoder.mid_block.attentions.0.proj_attn.weight", "encoder.mid_block.attentions.0.proj_attn.bias", "decoder.mid_block.attentions.0.query.weight", "decoder.mid_block.attentions.0.query.bias", "decoder.mid_block.attentions.0.key.weight", "decoder.mid_block.attentions.0.key.bias", "decoder.mid_block.attentions.0.value.weight", "decoder.mid_block.attentions.0.value.bias", "decoder.mid_block.attentions.0.proj_attn.weight", "decoder.mid_block.attentions.0.proj_attn.bias".
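For context: comparing the missing and unexpected keys shows the checkpoint uses the older diffusers attention names (query/key/value/proj_attn) while the installed diffusers version expects the newer ones (to_q/to_k/to_v/to_out.0). A minimal sketch of a rename that maps the old keys to the new ones (the helper name convert_vae_keys is mine, not from the DreamPose repo):

```python
def convert_vae_keys(state_dict):
    """Rename legacy diffusers VAE attention keys to the newer naming scheme."""
    renames = [
        ('query.', 'to_q.'),
        ('key.', 'to_k.'),
        ('value.', 'to_v.'),
        ('proj_attn.', 'to_out.0.'),  # proj_attn maps to the first layer of to_out
    ]
    new_sd = {}
    for key, value in state_dict.items():
        for old, new in renames:
            key = key.replace(old, new)
        new_sd[key] = value
    return new_sd

# quick check on one of the keys from the traceback above
old = {'encoder.mid_block.attentions.0.query.weight': 0}
print(convert_vae_keys(old))
# → {'encoder.mid_block.attentions.0.to_q.weight': 0}
```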
have you fixed it yet? I have the same problem.
No, I do not have a fix for this.
I have the same problem.
I replaced the keywords and it worked.
Same! I have replaced it like this and now it doesn't give me this error anymore:
name = name.replace('query.', 'to_q.')
name = name.replace('key.', 'to_k.')
name = name.replace('value.', 'to_v.')
name = name.replace('proj_attn.', 'to_out.')
name = name.replace('.mid_block.attentions.0.to_out.', '.mid_block.attentions.0.to_out.0.')
@LaiaTarres Hi Laia, can you please help me find where I should replace these keywords (file name + line)? Thank you so much.
@LaiaTarres Can you send the code that needs to be changed or replaced? I would love to test it: [email protected] Thank you.
I fixed this by modifying the loop (this is a common issue; all you need to do is change the state dict names so that they match the keys the model expects):
new_state_dict = {}
for k, v in vae_state_dict.items():
    name = k.replace('module.', '')  # strip any DataParallel 'module.' prefix
    name = name.replace('query', 'to_q')
    name = name.replace('key', 'to_k')
    name = name.replace('value', 'to_v')
    name = name.replace('proj_attn', 'to_out.0')
    new_state_dict[name] = v
pipe.vae.load_state_dict(new_state_dict)
pipe.vae = pipe.vae.cuda()
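If the load still fails after the replacements, it helps to diff the renamed keys against what the model expects before calling load_state_dict. A pure-Python sketch of that check (the toy key sets stand in for pipe.vae.state_dict() and new_state_dict; in your script you would build them from the real objects):

```python
# toy stand-ins: what the model expects vs. what the renamed checkpoint provides
expected = {'encoder.mid_block.attentions.0.to_q.weight',
            'encoder.mid_block.attentions.0.to_q.bias'}
provided = {'encoder.mid_block.attentions.0.to_q.weight',
            'encoder.mid_block.attentions.0.query.bias'}  # one key left unrenamed

# anything in 'missing' was not renamed correctly; 'unexpected' shows the culprit
missing = sorted(expected - provided)
unexpected = sorted(provided - expected)
print('missing:   ', missing)
print('unexpected:', unexpected)
```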