DALLE2-pytorch

How to eval diffusion prior?

Open jiamingzhang94 opened this issue 7 months ago • 0 comments

I have downloaded the checkpoints from https://huggingface.co/nousr/conditioned-prior/tree/main/vit-b-32. However, the README only shows how to load the ViT-L/14 checkpoint. How do I load the ViT-B/32 one?

I tried my best to adapt the README code, but it doesn't work.

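`import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

# Assumption: device was not defined in the original snippet.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Note: despite its name, this holds the loaded checkpoint object, not a file path.
path = torch.load('C:/Users/admin/Downloads/ema472M.pth', map_location='cpu')

def load_diffusion_model(dprior_path):
    # Transformer that predicts the image embedding from the text embedding.
    prior_network = DiffusionPriorNetwork(
        dim=512,  # Adjusted for ViT-B/32
        depth=12,  # Adjusted for ViT-B/32
        dim_head=64,
        heads=8,  # Adjusted for ViT-B/32
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    # Diffusion wrapper around the network, using CLIP ViT-B/32 embeddings.
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-B/32"),
        image_embed_dim=512,  # Adjusted for ViT-B/32
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # The trainer is only used here to restore the (EMA) weights.
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer

a = load_diffusion_model(path)`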

The traceback is below:

RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "noise_scheduler.betas", "noise_scheduler.alphas_cumprod", "noise_scheduler.alphas_cumprod_prev", "noise_scheduler.sqrt_alphas_cumprod", "noise_scheduler.sqrt_one_minus_alphas_cumprod", "noise_scheduler.log_one_minus_alphas_cumprod", "noise_scheduler.sqrt_recip_alphas_cumprod", "noise_scheduler.sqrt_recipm1_alphas_cumprod", "noise_scheduler.posterior_variance", "noise_scheduler.posterior_log_variance_clipped", "noise_scheduler.posterior_mean_coef1", "noise_scheduler.posterior_mean_coef2", "noise_scheduler.p2_loss_weight", "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed", "net.causal_transformer.layers.0.0.norm.g", "net.causal_transformer.layers.0.0.to_out.1.g", "net.causal_transformer.layers.0.1.0.g", "net.causal_transformer.layers.0.1.3.g", "net.causal_transformer.layers.1.0.norm.g", "net.causal_transformer.layers.1.0.to_out.1.g", "net.causal_transformer.layers.1.1.0.g", "net.causal_transformer.layers.1.1.3.g", "net.causal_transformer.layers.2.0.norm.g", "net.causal_transformer.layers.2.0.to_out.1.g", "net.causal_transformer.layers.2.1.0.g", "net.causal_transformer.layers.2.1.3.g", "net.causal_transformer.layers.3.0.norm.g", "net.causal_transformer.layers.3.0.to_out.1.g", "net.causal_transformer.layers.3.1.0.g", "net.causal_transformer.layers.3.1.3.g", "net.causal_transformer.layers.4.0.norm.g", "net.causal_transformer.layers.4.0.to_out.1.g", "net.causal_transformer.layers.4.1.0.g", "net.causal_transformer.layers.4.1.3.g", "net.causal_transformer.layers.5.0.norm.g", "net.causal_transformer.layers.5.0.to_out.1.g", "net.causal_transformer.layers.5.1.0.g", "net.causal_transformer.layers.5.1.3.g", "net.causal_transformer.layers.6.0.norm.g", "net.causal_transformer.layers.6.0.to_out.1.g", "net.causal_transformer.layers.6.1.0.g", "net.causal_transformer.layers.6.1.3.g", "net.causal_transformer.layers.7.0.norm.g", 
"net.causal_transformer.layers.7.0.to_out.1.g", "net.causal_transformer.layers.7.1.0.g", "net.causal_transformer.layers.7.1.3.g", "net.causal_transformer.layers.8.0.norm.g", "net.causal_transformer.layers.8.0.to_out.1.g", "net.causal_transformer.layers.8.1.0.g", "net.causal_transformer.layers.8.1.3.g", "net.causal_transformer.layers.9.0.norm.g", "net.causal_transformer.layers.9.0.to_out.1.g", "net.causal_transformer.layers.9.1.0.g", "net.causal_transformer.layers.9.1.3.g", "net.causal_transformer.layers.10.0.norm.g", "net.causal_transformer.layers.10.0.to_out.1.g", "net.causal_transformer.layers.10.1.0.g", "net.causal_transformer.layers.10.1.3.g", "net.causal_transformer.layers.11.0.norm.g", "net.causal_transformer.layers.11.0.to_out.1.g", "net.causal_transformer.layers.11.1.0.g", "net.causal_transformer.layers.11.1.3.g", "net.causal_transformer.norm.g". Unexpected key(s) in state_dict: "betas", "alphas_cumprod", "alphas_cumprod_prev", "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod", "log_one_minus_alphas_cumprod", "sqrt_recip_alphas_cumprod", "sqrt_recipm1_alphas_cumprod", "posterior_variance", "posterior_log_variance_clipped", "posterior_mean_coef1", "posterior_mean_coef2", "net.causal_transformer.layers.0.0.norm.gamma", "net.causal_transformer.layers.0.0.norm.beta", "net.causal_transformer.layers.0.0.to_out.1.gamma", "net.causal_transformer.layers.0.0.to_out.1.beta", "net.causal_transformer.layers.0.1.0.gamma", "net.causal_transformer.layers.0.1.0.beta", "net.causal_transformer.layers.0.1.3.gamma", "net.causal_transformer.layers.0.1.3.beta", "net.causal_transformer.layers.1.0.norm.gamma", "net.causal_transformer.layers.1.0.norm.beta", "net.causal_transformer.layers.1.0.to_out.1.gamma", "net.causal_transformer.layers.1.0.to_out.1.beta", "net.causal_transformer.layers.1.1.0.gamma", "net.causal_transformer.layers.1.1.0.beta", "net.causal_transformer.layers.1.1.3.gamma", "net.causal_transformer.layers.1.1.3.beta", 
"net.causal_transformer.layers.2.0.norm.gamma", "net.causal_transformer.layers.2.0.norm.beta", "net.causal_transformer.layers.2.0.to_out.1.gamma", "net.causal_transformer.layers.2.0.to_out.1.beta", "net.causal_transformer.layers.2.1.0.gamma", "net.causal_transformer.layers.2.1.0.beta", "net.causal_transformer.layers.2.1.3.gamma", "net.causal_transformer.layers.2.1.3.beta", "net.causal_transformer.layers.3.0.norm.gamma", "net.causal_transformer.layers.3.0.norm.beta", "net.causal_transformer.layers.3.0.to_out.1.gamma", "net.causal_transformer.layers.3.0.to_out.1.beta", "net.causal_transformer.layers.3.1.0.gamma", "net.causal_transformer.layers.3.1.0.beta", "net.causal_transformer.layers.3.1.3.gamma", "net.causal_transformer.layers.3.1.3.beta", "net.causal_transformer.layers.4.0.norm.gamma", "net.causal_transformer.layers.4.0.norm.beta", "net.causal_transformer.layers.4.0.to_out.1.gamma", "net.causal_transformer.layers.4.0.to_out.1.beta", "net.causal_transformer.layers.4.1.0.gamma", "net.causal_transformer.layers.4.1.0.beta", "net.causal_transformer.layers.4.1.3.gamma", "net.causal_transformer.layers.4.1.3.beta", "net.causal_transformer.layers.5.0.norm.gamma", "net.causal_transformer.layers.5.0.norm.beta", "net.causal_transformer.layers.5.0.to_out.1.gamma", "net.causal_transformer.layers.5.0.to_out.1.beta", "net.causal_transformer.layers.5.1.0.gamma", "net.causal_transformer.layers.5.1.0.beta", "net.causal_transformer.layers.5.1.3.gamma", "net.causal_transformer.layers.5.1.3.beta", "net.causal_transformer.layers.6.0.norm.gamma", "net.causal_transformer.layers.6.0.norm.beta", "net.causal_transformer.layers.6.0.to_out.1.gamma", "net.causal_transformer.layers.6.0.to_out.1.beta", "net.causal_transformer.layers.6.1.0.gamma", "net.causal_transformer.layers.6.1.0.beta", "net.causal_transformer.layers.6.1.3.gamma", "net.causal_transformer.layers.6.1.3.beta", "net.causal_transformer.layers.7.0.norm.gamma", "net.causal_transformer.layers.7.0.norm.beta", 
"net.causal_transformer.layers.7.0.to_out.1.gamma", "net.causal_transformer.layers.7.0.to_out.1.beta", "net.causal_transformer.layers.7.1.0.gamma", "net.causal_transformer.layers.7.1.0.beta", "net.causal_transformer.layers.7.1.3.gamma", "net.causal_transformer.layers.7.1.3.beta", "net.causal_transformer.layers.8.0.norm.gamma", "net.causal_transformer.layers.8.0.norm.beta", "net.causal_transformer.layers.8.0.to_out.1.gamma", "net.causal_transformer.layers.8.0.to_out.1.beta", "net.causal_transformer.layers.8.1.0.gamma", "net.causal_transformer.layers.8.1.0.beta", "net.causal_transformer.layers.8.1.3.gamma", "net.causal_transformer.layers.8.1.3.beta", "net.causal_transformer.layers.9.0.norm.gamma", "net.causal_transformer.layers.9.0.norm.beta", "net.causal_transformer.layers.9.0.to_out.1.gamma", "net.causal_transformer.layers.9.0.to_out.1.beta", "net.causal_transformer.layers.9.1.0.gamma", "net.causal_transformer.layers.9.1.0.beta", "net.causal_transformer.layers.9.1.3.gamma", "net.causal_transformer.layers.9.1.3.beta", "net.causal_transformer.layers.10.0.norm.gamma", "net.causal_transformer.layers.10.0.norm.beta", "net.causal_transformer.layers.10.0.to_out.1.gamma", "net.causal_transformer.layers.10.0.to_out.1.beta", "net.causal_transformer.layers.10.1.0.gamma", "net.causal_transformer.layers.10.1.0.beta", "net.causal_transformer.layers.10.1.3.gamma", "net.causal_transformer.layers.10.1.3.beta", "net.causal_transformer.layers.11.0.norm.gamma", "net.causal_transformer.layers.11.0.norm.beta", "net.causal_transformer.layers.11.0.to_out.1.gamma", "net.causal_transformer.layers.11.0.to_out.1.beta", "net.causal_transformer.layers.11.1.0.gamma", "net.causal_transformer.layers.11.1.0.beta", "net.causal_transformer.layers.11.1.3.gamma", "net.causal_transformer.layers.11.1.3.beta", "net.causal_transformer.norm.gamma", "net.causal_transformer.norm.beta". 
size mismatch for net.to_time_embeds.0.weight: copying a param with shape torch.Size([100, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]). size mismatch for net.causal_transformer.rel_pos_bias.relative_attention_bias.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([32, 8]). size mismatch for net.causal_transformer.layers.0.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.0.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.1.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.1.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.2.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.2.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.3.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.3.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). 
size mismatch for net.causal_transformer.layers.4.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.4.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.5.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.5.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.6.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.6.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.7.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.7.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.8.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.8.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). 
size mismatch for net.causal_transformer.layers.9.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.9.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.10.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.10.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.11.0.to_q.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]). size mismatch for net.causal_transformer.layers.11.0.to_out.0.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
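
The size-mismatch messages above already hint at the settings the checkpoint was saved with. The following is a minimal sketch, not part of DALLE2-pytorch: `infer_prior_hparams` is a hypothetical helper, and it assumes the layout implied by the traceback (time embedding of shape `[num_timesteps, dim]`, relative position bias of shape `[buckets, heads]`, and `to_q` of shape `[heads * dim_head, dim]`). The demo uses only the shapes quoted in the error, so it runs without the checkpoint file.

```python
def infer_prior_hparams(shapes):
    """Guess DiffusionPriorNetwork arguments from checkpoint tensor shapes.

    shapes maps state_dict keys to shape tuples. The key names and their
    meaning are assumptions read off the traceback, not a documented API.
    """
    prefix = "net.causal_transformer.layers."
    # Layer indices appear right after the prefix, e.g. "...layers.10.0.to_q.weight".
    layer_ids = {
        int(key[len(prefix):].split(".")[0])
        for key in shapes
        if key.startswith(prefix)
    }
    num_timesteps, dim = shapes["net.to_time_embeds.0.weight"]
    inner_dim, _ = shapes[prefix + "0.0.to_q.weight"]
    _, heads = shapes[
        "net.causal_transformer.rel_pos_bias.relative_attention_bias.weight"
    ]
    return {
        "dim": dim,
        "depth": max(layer_ids) + 1,
        "heads": heads,
        "dim_head": inner_dim // heads,
        "num_timesteps": num_timesteps,
    }


# Shapes copied from the "copying a param with shape ..." parts of the error.
checkpoint_shapes = {
    "net.to_time_embeds.0.weight": (100, 512),
    "net.causal_transformer.rel_pos_bias.relative_attention_bias.weight": (32, 16),
}
for i in range(12):
    checkpoint_shapes[f"net.causal_transformer.layers.{i}.0.to_q.weight"] = (1024, 512)
    checkpoint_shapes[f"net.causal_transformer.layers.{i}.0.to_out.0.weight"] = (512, 1024)

hparams = infer_prior_hparams(checkpoint_shapes)
print(hparams)
```

With a real checkpoint, the same mapping could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`, where `state_dict` is whichever entry of the loaded object holds the model weights. For the shapes above, this suggests `heads=16` and `num_timesteps=100` rather than the `heads=8` and `num_timesteps=1000` used in the snippet. The missing and unexpected key names (`g` versus `gamma`/`beta`, and the `noise_scheduler.` prefix) are a separate problem: they appear to indicate that the checkpoint was saved with a different version of the library than the one installed, which changing hyperparameters alone would not fix.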

jiamingzhang94 · Jul 17 '24 06:07