About the weight file of the prior model
Hello, thank you very much for your work! I would like to use the pre-trained prior model for my own research. Could you tell me which link points to the weight file?
I'm busy right now. I will send it to you in two days. Thanks.
Hi @AIDevMonster,
Could you provide some example code for using the pre-trained vit-b-32 prior model? I have downloaded the checkpoints from https://huggingface.co/nousr/conditioned-prior/tree/main/vit-b-32. However, the readme only shows how to load the ViT-L/14 checkpoints. I tried my best to adapt that code, but it doesn't work.
This is very important for my research. Thanks a lot!
The traceback is below.
RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "noise_scheduler.betas", "noise_scheduler.alphas_cumprod", "noise_scheduler.alphas_cumprod_prev", "noise_scheduler.sqrt_alphas_cumprod", "noise_scheduler.sqrt_one_minus_alphas_cumprod", "noise_scheduler.log_one_minus_alphas_cumprod", "noise_scheduler.sqrt_recip_alphas_cumprod", "noise_scheduler.sqrt_recipm1_alphas_cumprod", "noise_scheduler.posterior_variance", "noise_scheduler.posterior_log_variance_clipped", "noise_scheduler.posterior_mean_coef1", "noise_scheduler.posterior_mean_coef2", "noise_scheduler.p2_loss_weight", "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed", "net.causal_transformer.layers.0.0.norm.g", "net.causal_transformer.layers.0.0.to_out.1.g", "net.causal_transformer.layers.0.1.0.g", "net.causal_transformer.layers.0.1.3.g", "net.causal_transformer.layers.1.0.norm.g", "net.causal_transformer.layers.1.0.to_out.1.g", "net.causal_transformer.layers.1.1.0.g", "net.causal_transformer.layers.1.1.3.g", "net.causal_transformer.layers.2.0.norm.g", "net.causal_transformer.layers.2.0.to_out.1.g", "net.causal_transformer.layers.2.1.0.g", "net.causal_transformer.layers.2.1.3.g", "net.causal_transformer.layers.3.0.norm.g", "net.causal_transformer.layers.3.0.to_out.1.g", "net.causal_transformer.layers.3.1.0.g", "net.causal_transformer.layers.3.1.3.g", "net.causal_transformer.layers.4.0.norm.g", "net.causal_transformer.layers.4.0.to_out.1.g", "net.causal_transformer.layers.4.1.0.g", "net.causal_transformer.layers.4.1.3.g", "net.causal_transformer.layers.5.0.norm.g", "net.causal_transformer.layers.5.0.to_out.1.g", "net.causal_transformer.layers.5.1.0.g", "net.causal_transformer.layers.5.1.3.g", "net.causal_transformer.layers.6.0.norm.g", "net.causal_transformer.layers.6.0.to_out.1.g", "net.causal_transformer.layers.6.1.0.g", "net.causal_transformer.layers.6.1.3.g", "net.causal_transformer.layers.7.0.norm.g", 
"net.causal_transformer.layers.7.0.to_out.1.g", "net.causal_transformer.layers.7.1.0.g", "net.causal_transformer.layers.7.1.3.g", "net.causal_transformer.layers.8.0.norm.g", "net.causal_transformer.layers.8.0.to_out.1.g", "net.causal_transformer.layers.8.1.0.g", "net.causal_transformer.layers.8.1.3.g", "net.causal_transformer.layers.9.0.norm.g", "net.causal_transformer.layers.9.0.to_out.1.g", "net.causal_transformer.layers.9.1.0.g", "net.causal_transformer.layers.9.1.3.g", "net.causal_transformer.layers.10.0.norm.g", "net.causal_transformer.layers.10.0.to_out.1.g", "net.causal_transformer.layers.10.1.0.g", "net.causal_transformer.layers.10.1.3.g", "net.causal_transformer.layers.11.0.norm.g", "net.causal_transformer.layers.11.0.to_out.1.g", "net.causal_transformer.layers.11.1.0.g", "net.causal_transformer.layers.11.1.3.g", "net.causal_transformer.norm.g". Unexpected key(s) in state_dict: "betas", "alphas_cumprod", "alphas_cumprod_prev", "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod", "log_one_minus_alphas_cumprod", "sqrt_recip_alphas_cumprod", "sqrt_recipm1_alphas_cumprod", "posterior_variance", "posterior_log_variance_clipped", "posterior_mean_coef1", "posterior_mean_coef2", "net.causal_transformer.layers.0.0.norm.gamma", "net.causal_transformer.layers.0.0.norm.beta", "net.causal_transformer.layers.0.0.to_out.1.gamma", "net.causal_transformer.layers.0.0.to_out.1.beta", "net.causal_transformer.layers.0.1.0.gamma", "net.causal_transformer.layers.0.1.0.beta", "net.causal_transformer.layers.0.1.3.gamma", "net.causal_transformer.layers.0.1.3.beta", "net.causal_transformer.layers.1.0.norm.gamma", "net.causal_transformer.layers.1.0.norm.beta", "net.causal_transformer.layers.1.0.to_out.1.gamma", "net.causal_transformer.layers.1.0.to_out.1.beta", "net.causal_transformer.layers.1.1.0.gamma", "net.causal_transformer.layers.1.1.0.beta", "net.causal_transformer.layers.1.1.3.gamma", "net.causal_transformer.layers.1.1.3.beta", 
"net.causal_transformer.layers.2.0.norm.gamma", "net.causal_transformer.layers.2.0.norm.beta", "net.causal_transformer.layers.2.0.to_out.1.gamma", "net.causal_transformer.layers.2.0.to_out.1.beta", "net.causal_transformer.layers.2.1.0.gamma", "net.causal_transformer.layers.2.1.0.beta", "net.causal_transformer.layers.2.1.3.gamma", "net.causal_transformer.layers.2.1.3.beta", "net.causal_transformer.layers.3.0.norm.gamma", "net.causal_transformer.layers.3.0.norm.beta", "net.causal_transformer.layers.3.0.to_out.1.gamma", "net.causal_transformer.layers.3.0.to_out.1.beta", "net.causal_transformer.layers.3.1.0.gamma", "net.causal_transformer.layers.3.1.0.beta", "net.causal_transformer.layers.3.1.3.gamma", "net.causal_transformer.layers.3.1.3.beta", "net.causal_transformer.layers.4.0.norm.gamma", "net.causal_transformer.layers.4.0.norm.beta", "net.causal_transformer.layers.4.0.to_out.1.gamma", "net.causal_transformer.layers.4.0.to_out.1.beta", "net.causal_transformer.layers.4.1.0.gamma", "net.causal_transformer.layers.4.1.0.beta", "net.causal_transformer.layers.4.1.3.gamma", "net.causal_transformer.layers.4.1.3.beta", "net.causal_transformer.layers.5.0.norm.gamma", "net.causal_transformer.layers.5.0.norm.beta", "net.causal_transformer.layers.5.0.to_out.1.gamma", "net.causal_transformer.layers.5.0.to_out.1.beta", "net.causal_transformer.layers.5.1.0.gamma", "net.causal_transformer.layers.5.1.0.beta", "net.causal_transformer.layers.5.1.3.gamma", "net.causal_transformer.layers.5.1.3.beta", "net.causal_transformer.layers.6.0.norm.gamma", "net.causal_transformer.layers.6.0.norm.beta", "net.causal_transformer.layers.6.0.to_out.1.gamma", "net.causal_transformer.layers.6.0.to_out.1.beta", "net.causal_transformer.layers.6.1.0.gamma", "net.causal_transformer.layers.6.1.0.beta", "net.causal_transformer.layers.6.1.3.gamma", "net.causal_transformer.layers.6.1.3.beta", "net.causal_transformer.layers.7.0.norm.gamma", "net.causal_transformer.layers.7.0.norm.beta", 
"net.causal_transformer.layers.7.0.to_out.1.gamma", "net.causal_transformer.layers.7.0.to_out.1.beta", "net.causal_transformer.layers.7.1.0.gamma", "net.causal_transformer.layers.7.1.0.beta", "net.causal_transformer.layers.7.1.3.gamma", "net.causal_transformer.layers.7.1.3.beta", "net.causal_transformer.layers.8.0.norm.gamma", "net.causal_transformer.layers.8.0.norm.beta", "net.causal_transformer.layers.8.0.to_out.1.gamma", "net.causal_transformer.layers.8.0.to_out.1.beta", "net.causal_transformer.layers.8.1.0.gamma", "net.causal_transformer.layers.8.1.0.beta", "net.causal_transformer.layers.8.1.3.gamma", "net.causal_transformer.layers.8.1.3.beta", "net.causal_transformer.layers.9.0.norm.gamma", "net.causal_transformer.layers.9.0.norm.beta", "net.causal_transformer.layers.9.0.to_out.1.gamma", "net.causal_transformer.layers.9.0.to_out.1.beta", "net.causal_transformer.layers.9.1.0.gamma", "net.causal_transformer.layers.9.1.0.beta", "net.causal_transformer.layers.9.1.3.gamma", "net.causal_transformer.layers.9.1.3.beta", "net.causal_transformer.layers.10.0.norm.gamma", "net.causal_transformer.layers.10.0.norm.beta", "net.causal_transformer.layers.10.0.to_out.1.gamma", "net.causal_transformer.layers.10.0.to_out.1.beta", "net.causal_transformer.layers.10.1.0.gamma", "net.causal_transformer.layers.10.1.0.beta", "net.causal_transformer.layers.10.1.3.gamma", "net.causal_transformer.layers.10.1.3.beta", "net.causal_transformer.layers.11.0.norm.gamma", "net.causal_transformer.layers.11.0.norm.beta", "net.causal_transformer.layers.11.0.to_out.1.gamma", "net.causal_transformer.layers.11.0.to_out.1.beta", "net.causal_transformer.layers.11.1.0.gamma", "net.causal_transformer.layers.11.1.0.beta", "net.causal_transformer.layers.11.1.3.gamma", "net.causal_transformer.layers.11.1.3.beta", "net.causal_transformer.norm.gamma", "net.causal_transformer.norm.beta". 
size mismatch for net.to_time_embeds.0.weight: copying a param with shape torch.Size([100, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
size mismatch for net.causal_transformer.rel_pos_bias.relative_attention_bias.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([32, 8]).
size mismatch for net.causal_transformer.layers.0.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.0.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.1.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.1.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.2.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.2.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.3.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.3.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.4.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.4.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.5.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.5.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.6.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.6.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.7.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.7.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.8.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.8.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.9.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.9.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.10.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.10.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.11.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for net.causal_transformer.layers.11.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
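In case it helps with debugging, the size mismatches in this traceback already hint at how the checkpoint's network was configured. Below is a minimal sketch in plain Python (no PyTorch required) that reads the likely hyperparameters off the checkpoint shapes. The helper names `shape_mismatches` and `infer_prior_hparams` are hypothetical, and how each tensor axis maps to a hyperparameter (embedding rows as timesteps, bias columns as heads, `to_q` rows as `heads * dim_head`) is my assumption, not something confirmed by the checkpoint's author. The `gamma`/`beta` versus `g` key names also suggest the checkpoint was saved with a different version of the library than the one installed, but that is a guess as well.

```python
import re
from typing import Dict, Mapping, Tuple

Shape = Tuple[int, ...]

# With PyTorch available, a shape table could be built from a real state dict:
#   shapes = {name: tuple(t.shape) for name, t in state_dict.items()}

LAYER_RE = re.compile(r"net\.causal_transformer\.layers\.(\d+)\.")


def shape_mismatches(
    ckpt: Mapping[str, Shape], model: Mapping[str, Shape]
) -> Dict[str, Tuple[Shape, Shape]]:
    """Return {key: (checkpoint_shape, model_shape)} for shared keys that differ."""
    return {
        key: (ckpt[key], model[key])
        for key in ckpt
        if key in model and ckpt[key] != model[key]
    }


def infer_prior_hparams(ckpt: Mapping[str, Shape]) -> Dict[str, int]:
    """Guess network hyperparameters from checkpoint shapes (assumed layout)."""
    # Assumption: time embedding table is (num_timesteps, dim).
    num_timesteps, dim = ckpt["net.to_time_embeds.0.weight"]
    # Assumption: relative position bias table is (num_buckets, heads).
    _, heads = ckpt[
        "net.causal_transformer.rel_pos_bias.relative_attention_bias.weight"
    ]
    # Assumption: to_q weight is (heads * dim_head, dim).
    inner_dim, _ = ckpt["net.causal_transformer.layers.0.0.to_q.weight"]
    # Depth is taken as the highest layer index seen, plus one.
    layer_ids = [int(m.group(1)) for m in map(LAYER_RE.match, ckpt) if m]
    return {
        "dim": dim,
        "depth": max(layer_ids) + 1,
        "heads": heads,
        "dim_head": inner_dim // heads,
        "num_timesteps": num_timesteps,
    }


# Shapes copied from the traceback above (layers 1 to 10 omitted for brevity).
ckpt_shapes: Dict[str, Shape] = {
    "net.to_time_embeds.0.weight": (100, 512),
    "net.causal_transformer.rel_pos_bias.relative_attention_bias.weight": (32, 16),
    "net.causal_transformer.layers.0.0.to_q.weight": (1024, 512),
    "net.causal_transformer.layers.0.0.to_out.0.weight": (512, 1024),
    "net.causal_transformer.layers.11.0.to_q.weight": (1024, 512),
    "net.causal_transformer.layers.11.0.to_out.0.weight": (512, 1024),
}
model_shapes: Dict[str, Shape] = {
    "net.to_time_embeds.0.weight": (1000, 512),
    "net.causal_transformer.rel_pos_bias.relative_attention_bias.weight": (32, 8),
    "net.causal_transformer.layers.0.0.to_q.weight": (512, 512),
    "net.causal_transformer.layers.0.0.to_out.0.weight": (512, 512),
    "net.causal_transformer.layers.11.0.to_q.weight": (512, 512),
    "net.causal_transformer.layers.11.0.to_out.0.weight": (512, 512),
}

mismatches = shape_mismatches(ckpt_shapes, model_shapes)
hparams = infer_prior_hparams(ckpt_shapes)
print(hparams)
```

If those guesses hold, the model in the adapted script was built with a different number of attention heads and timesteps than the vit-b-32 checkpoint expects, so the constructor arguments copied from the ViT-L/14 readme example would presumably need to be changed to the inferred values before `load_state_dict` can succeed. The renamed keys would still need a matching library version or a key remapping.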