
Error on running Inference

Open tgmorais opened this issue 1 year ago • 3 comments

Running the inference sample:

python sample.py -m "DiT/XL-2" --text "a person is walking on the street" --ckpt /path/to/checkpoint --height 256 --width 256 --fps 10 --sec 5 --disable-cfg

I got the following error:

usage: sample.py [-h] [-m {DiT-XL/2,DiT-XL/4,DiT-XL/8,DiT-L/2,DiT-L/4,DiT-L/8,DiT-B/2,DiT-B/4,DiT-B/8,DiT-S/2,DiT-S/4,DiT-S/8}] [--text TEXT] [--cfg-scale CFG_SCALE] [--num-sampling-steps NUM_SAMPLING_STEPS] [--seed SEED] --ckpt CKPT [-c {raw,vqvae,vae}] [--text_model TEXT_MODEL] [--width WIDTH] [--height HEIGHT] [--fps FPS] [--sec SEC] [--disable-cfg]
sample.py: error: argument -m/--model: invalid choice: 'DiT/XL-2' (choose from 'DiT-XL/2', 'DiT-XL/4', 'DiT-XL/8', 'DiT-L/2', 'DiT-L/4', 'DiT-L/8', 'DiT-B/2', 'DiT-B/4', 'DiT-B/8', 'DiT-S/2', 'DiT-S/4', 'DiT-S/8')

Then I changed the command to:

!python sample.py -m "DiT-XL/2" --text "a person is walking on the street" --ckpt pretrained_models/DiT-XL-2-256x256.pt --height 256 --width 256 --fps 10 --sec 5 --disable-cfg

But I got a different error:

Traceback (most recent call last): File "/content/Open-Sora/sample.py", line 136, in main(args) File "/content/Open-Sora/sample.py", line 39, in main model.load_state_dict(torch.load(args.ckpt)) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2152, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for DiT: Missing key(s) in state_dict: "video_embedder.proj.weight", "video_embedder.proj.bias", "blocks.0.attn.to_q.weight", "blocks.0.attn.to_q.bias", "blocks.0.attn.to_k.weight", "blocks.0.attn.to_k.bias", "blocks.0.attn.to_v.weight", "blocks.0.attn.to_v.bias", "blocks.0.attn.to_out.0.weight", "blocks.0.attn.to_out.0.bias", "blocks.1.attn.to_q.weight", "blocks.1.attn.to_q.bias", "blocks.1.attn.to_k.weight", "blocks.1.attn.to_k.bias", "blocks.1.attn.to_v.weight", "blocks.1.attn.to_v.bias", "blocks.1.attn.to_out.0.weight", "blocks.1.attn.to_out.0.bias", "blocks.2.attn.to_q.weight", "blocks.2.attn.to_q.bias", "blocks.2.attn.to_k.weight", "blocks.2.attn.to_k.bias", "blocks.2.attn.to_v.weight", "blocks.2.attn.to_v.bias", "blocks.2.attn.to_out.0.weight", "blocks.2.attn.to_out.0.bias", "blocks.3.attn.to_q.weight", "blocks.3.attn.to_q.bias", "blocks.3.attn.to_k.weight", "blocks.3.attn.to_k.bias", "blocks.3.attn.to_v.weight", "blocks.3.attn.to_v.bias", "blocks.3.attn.to_out.0.weight", "blocks.3.attn.to_out.0.bias", "blocks.4.attn.to_q.weight", "blocks.4.attn.to_q.bias", "blocks.4.attn.to_k.weight", "blocks.4.attn.to_k.bias", "blocks.4.attn.to_v.weight", "blocks.4.attn.to_v.bias", "blocks.4.attn.to_out.0.weight", "blocks.4.attn.to_out.0.bias", "blocks.5.attn.to_q.weight", "blocks.5.attn.to_q.bias", "blocks.5.attn.to_k.weight", "blocks.5.attn.to_k.bias", "blocks.5.attn.to_v.weight", "blocks.5.attn.to_v.bias", "blocks.5.attn.to_out.0.weight", "blocks.5.attn.to_out.0.bias", "blocks.6.attn.to_q.weight", "blocks.6.attn.to_q.bias", "blocks.6.attn.to_k.weight", "blocks.6.attn.to_k.bias", "blocks.6.attn.to_v.weight", "blocks.6.attn.to_v.bias", "blocks.6.attn.to_out.0.weight", "blocks.6.attn.to_out.0.bias", "blocks.7.attn.to_q.weight", "blocks.7.attn.to_q.bias", "blocks.7.attn.to_k.weight", "blocks.7.attn.to_k.bias", "blocks.7.attn.to_v.weight", "blocks.7.attn.to_v.bias", "blocks.7.attn.to_out.0.weight", "blocks.7.attn.to_out.0.bias", "blocks.8.attn.to_q.weight", "blocks.8.attn.to_q.bias", "blocks.8.attn.to_k.weight", "blocks.8.attn.to_k.bias", "blocks.8.attn.to_v.weight", "blocks.8.attn.to_v.bias", "blocks.8.attn.to_out.0.weight", "blocks.8.attn.to_out.0.bias", "blocks.9.attn.to_q.weight", "blocks.9.attn.to_q.bias", "blocks.9.attn.to_k.weight", "blocks.9.attn.to_k.bias", "blocks.9.attn.to_v.weight", "blocks.9.attn.to_v.bias", "blocks.9.attn.to_out.0.weight", "blocks.9.attn.to_out.0.bias", "blocks.10.attn.to_q.weight", "blocks.10.attn.to_q.bias", "blocks.10.attn.to_k.weight", "blocks.10.attn.to_k.bias", "blocks.10.attn.to_v.weight", "blocks.10.attn.to_v.bias", "blocks.10.attn.to_out.0.weight", "blocks.10.attn.to_out.0.bias", "blocks.11.attn.to_q.weight", "blocks.11.attn.to_q.bias", "blocks.11.attn.to_k.weight", "blocks.11.attn.to_k.bias", "blocks.11.attn.to_v.weight", "blocks.11.attn.to_v.bias", "blocks.11.attn.to_out.0.weight", "blocks.11.attn.to_out.0.bias", "blocks.12.attn.to_q.weight", "blocks.12.attn.to_q.bias", "blocks.12.attn.to_k.weight", "blocks.12.attn.to_k.bias", "blocks.12.attn.to_v.weight", "blocks.12.attn.to_v.bias", "blocks.12.attn.to_out.0.weight", 
"blocks.12.attn.to_out.0.bias", "blocks.13.attn.to_q.weight", "blocks.13.attn.to_q.bias", "blocks.13.attn.to_k.weight", "blocks.13.attn.to_k.bias", "blocks.13.attn.to_v.weight", "blocks.13.attn.to_v.bias", "blocks.13.attn.to_out.0.weight", "blocks.13.attn.to_out.0.bias", "blocks.14.attn.to_q.weight", "blocks.14.attn.to_q.bias", "blocks.14.attn.to_k.weight", "blocks.14.attn.to_k.bias", "blocks.14.attn.to_v.weight", "blocks.14.attn.to_v.bias", "blocks.14.attn.to_out.0.weight", "blocks.14.attn.to_out.0.bias", "blocks.15.attn.to_q.weight", "blocks.15.attn.to_q.bias", "blocks.15.attn.to_k.weight", "blocks.15.attn.to_k.bias", "blocks.15.attn.to_v.weight", "blocks.15.attn.to_v.bias", "blocks.15.attn.to_out.0.weight", "blocks.15.attn.to_out.0.bias", "blocks.16.attn.to_q.weight", "blocks.16.attn.to_q.bias", "blocks.16.attn.to_k.weight", "blocks.16.attn.to_k.bias", "blocks.16.attn.to_v.weight", "blocks.16.attn.to_v.bias", "blocks.16.attn.to_out.0.weight", "blocks.16.attn.to_out.0.bias", "blocks.17.attn.to_q.weight", "blocks.17.attn.to_q.bias", "blocks.17.attn.to_k.weight", "blocks.17.attn.to_k.bias", "blocks.17.attn.to_v.weight", "blocks.17.attn.to_v.bias", "blocks.17.attn.to_out.0.weight", "blocks.17.attn.to_out.0.bias", "blocks.18.attn.to_q.weight", "blocks.18.attn.to_q.bias", "blocks.18.attn.to_k.weight", "blocks.18.attn.to_k.bias", "blocks.18.attn.to_v.weight", "blocks.18.attn.to_v.bias", "blocks.18.attn.to_out.0.weight", "blocks.18.attn.to_out.0.bias", "blocks.19.attn.to_q.weight", "blocks.19.attn.to_q.bias", "blocks.19.attn.to_k.weight", "blocks.19.attn.to_k.bias", "blocks.19.attn.to_v.weight", "blocks.19.attn.to_v.bias", "blocks.19.attn.to_out.0.weight", "blocks.19.attn.to_out.0.bias", "blocks.20.attn.to_q.weight", "blocks.20.attn.to_q.bias", "blocks.20.attn.to_k.weight", "blocks.20.attn.to_k.bias", "blocks.20.attn.to_v.weight", "blocks.20.attn.to_v.bias", "blocks.20.attn.to_out.0.weight", "blocks.20.attn.to_out.0.bias", "blocks.21.attn.to_q.weight", "blocks.21.attn.to_q.bias", "blocks.21.attn.to_k.weight", "blocks.21.attn.to_k.bias", "blocks.21.attn.to_v.weight", "blocks.21.attn.to_v.bias", "blocks.21.attn.to_out.0.weight", "blocks.21.attn.to_out.0.bias", "blocks.22.attn.to_q.weight", "blocks.22.attn.to_q.bias", "blocks.22.attn.to_k.weight", "blocks.22.attn.to_k.bias", "blocks.22.attn.to_v.weight", "blocks.22.attn.to_v.bias", "blocks.22.attn.to_out.0.weight", "blocks.22.attn.to_out.0.bias", "blocks.23.attn.to_q.weight", "blocks.23.attn.to_q.bias", "blocks.23.attn.to_k.weight", "blocks.23.attn.to_k.bias", "blocks.23.attn.to_v.weight", "blocks.23.attn.to_v.bias", "blocks.23.attn.to_out.0.weight", "blocks.23.attn.to_out.0.bias", "blocks.24.attn.to_q.weight", "blocks.24.attn.to_q.bias", "blocks.24.attn.to_k.weight", "blocks.24.attn.to_k.bias", "blocks.24.attn.to_v.weight", "blocks.24.attn.to_v.bias", "blocks.24.attn.to_out.0.weight", "blocks.24.attn.to_out.0.bias", "blocks.25.attn.to_q.weight", "blocks.25.attn.to_q.bias", "blocks.25.attn.to_k.weight", "blocks.25.attn.to_k.bias", "blocks.25.attn.to_v.weight", "blocks.25.attn.to_v.bias", "blocks.25.attn.to_out.0.weight", "blocks.25.attn.to_out.0.bias", "blocks.26.attn.to_q.weight", "blocks.26.attn.to_q.bias", "blocks.26.attn.to_k.weight", "blocks.26.attn.to_k.bias", "blocks.26.attn.to_v.weight", "blocks.26.attn.to_v.bias", "blocks.26.attn.to_out.0.weight", "blocks.26.attn.to_out.0.bias", "blocks.27.attn.to_q.weight", "blocks.27.attn.to_q.bias", "blocks.27.attn.to_k.weight", "blocks.27.attn.to_k.bias", "blocks.27.attn.to_v.weight", 
"blocks.27.attn.to_v.bias", "blocks.27.attn.to_out.0.weight", "blocks.27.attn.to_out.0.bias". Unexpected key(s) in state_dict: "y_embedder.embedding_table.weight", "x_embedder.proj.weight", "x_embedder.proj.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "blocks.5.attn.qkv.weight", "blocks.5.attn.qkv.bias", "blocks.5.attn.proj.weight", "blocks.5.attn.proj.bias", "blocks.6.attn.qkv.weight", "blocks.6.attn.qkv.bias", "blocks.6.attn.proj.weight", "blocks.6.attn.proj.bias", "blocks.7.attn.qkv.weight", "blocks.7.attn.qkv.bias", "blocks.7.attn.proj.weight", "blocks.7.attn.proj.bias", "blocks.8.attn.qkv.weight", "blocks.8.attn.qkv.bias", "blocks.8.attn.proj.weight", "blocks.8.attn.proj.bias", "blocks.9.attn.qkv.weight", "blocks.9.attn.qkv.bias", "blocks.9.attn.proj.weight", "blocks.9.attn.proj.bias", "blocks.10.attn.qkv.weight", "blocks.10.attn.qkv.bias", "blocks.10.attn.proj.weight", "blocks.10.attn.proj.bias", "blocks.11.attn.qkv.weight", "blocks.11.attn.qkv.bias", "blocks.11.attn.proj.weight", "blocks.11.attn.proj.bias", "blocks.12.attn.qkv.weight", "blocks.12.attn.qkv.bias", "blocks.12.attn.proj.weight", "blocks.12.attn.proj.bias", "blocks.13.attn.qkv.weight", "blocks.13.attn.qkv.bias", "blocks.13.attn.proj.weight", "blocks.13.attn.proj.bias", "blocks.14.attn.qkv.weight", "blocks.14.attn.qkv.bias", "blocks.14.attn.proj.weight", "blocks.14.attn.proj.bias", "blocks.15.attn.qkv.weight", "blocks.15.attn.qkv.bias", "blocks.15.attn.proj.weight", "blocks.15.attn.proj.bias", "blocks.16.attn.qkv.weight", "blocks.16.attn.qkv.bias", "blocks.16.attn.proj.weight", "blocks.16.attn.proj.bias", "blocks.17.attn.qkv.weight", "blocks.17.attn.qkv.bias", "blocks.17.attn.proj.weight", "blocks.17.attn.proj.bias", "blocks.18.attn.qkv.weight", "blocks.18.attn.qkv.bias", "blocks.18.attn.proj.weight", "blocks.18.attn.proj.bias", "blocks.19.attn.qkv.weight", "blocks.19.attn.qkv.bias", "blocks.19.attn.proj.weight", "blocks.19.attn.proj.bias", "blocks.20.attn.qkv.weight", "blocks.20.attn.qkv.bias", "blocks.20.attn.proj.weight", "blocks.20.attn.proj.bias", "blocks.21.attn.qkv.weight", "blocks.21.attn.qkv.bias", "blocks.21.attn.proj.weight", "blocks.21.attn.proj.bias", "blocks.22.attn.qkv.weight", "blocks.22.attn.qkv.bias", "blocks.22.attn.proj.weight", "blocks.22.attn.proj.bias", "blocks.23.attn.qkv.weight", "blocks.23.attn.qkv.bias", "blocks.23.attn.proj.weight", "blocks.23.attn.proj.bias", "blocks.24.attn.qkv.weight", "blocks.24.attn.qkv.bias", "blocks.24.attn.proj.weight", "blocks.24.attn.proj.bias", "blocks.25.attn.qkv.weight", "blocks.25.attn.qkv.bias", "blocks.25.attn.proj.weight", "blocks.25.attn.proj.bias", "blocks.26.attn.qkv.weight", "blocks.26.attn.qkv.bias", "blocks.26.attn.proj.weight", "blocks.26.attn.proj.bias", "blocks.27.attn.qkv.weight", "blocks.27.attn.qkv.bias", "blocks.27.attn.proj.weight", "blocks.27.attn.proj.bias". size mismatch for final_layer.linear.weight: copying a param with shape torch.Size([32, 1152]) from checkpoint, the shape in current model is torch.Size([24, 1152]). 
size mismatch for final_layer.linear.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([24]).

I would appreciate any help solving this.

tgmorais avatar Mar 05 '24 11:03 tgmorais

Hi, are you using the original DiT pretrained weights? We modified the model architecture, so the original weights cannot be loaded. You can save compatible pretrained weights using our training script.
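
As a quick sanity check before running sample.py, you can inspect which kind of weights a checkpoint file actually contains. The snippet below is plain PyTorch rather than code from this repo, and it assumes the file is a flat state dict; the prefixes it looks for are the ones that appear in the error above, and the path is just an example.

import torch

path = "pretrained_models/DiT-XL-2-256x256.pt"  # hypothetical path: whatever you pass to --ckpt
state_dict = torch.load(path, map_location="cpu")
if isinstance(state_dict, dict) and "model" in state_dict:
    # some training scripts nest the weights under a "model" key
    state_dict = state_dict["model"]

# Checkpoints trained with this repo's modified DiT should contain video_embedder.*
# and separate attn.to_q/to_k/to_v weights; the original image DiT release has
# x_embedder.* and fused attn.qkv weights instead (see the missing/unexpected keys above).
print("video_embedder keys:", any(k.startswith("video_embedder.") for k in state_dict))
print("x_embedder keys:", any(k.startswith("x_embedder.") for k in state_dict))
print("fused qkv keys:", any(".attn.qkv." in k for k in state_dict))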

ver217 avatar Mar 06 '24 03:03 ver217

There are key names with a video_embedder prefix, so it seems the weights need to come from training the network in this repo.

unanan avatar Mar 06 '24 06:03 unanan

Maybe you trained DiT-S/8 by default?
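
If you are unsure which variant a checkpoint was trained with, one rough check (plain PyTorch, assuming the standard DiT widths of 384/768/1024/1152 for S/B/L/XL, with a hypothetical path) is to read the hidden width off the final layer:

import torch

WIDTHS = {384: "DiT-S", 768: "DiT-B", 1024: "DiT-L", 1152: "DiT-XL"}
state_dict = torch.load("path/to/checkpoint.pt", map_location="cpu")  # hypothetical path
# final_layer.linear.weight has shape (patch_size**2 * out_channels, hidden_size),
# so its second dimension tells you the model width.
hidden = state_dict["final_layer.linear.weight"].shape[1]
print(WIDTHS.get(hidden, f"unknown width: {hidden}"))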

jeffchy avatar Mar 08 '24 04:03 jeffchy

We have updated our inference code. Please see here for more instructions.

zhengzangw avatar Mar 18 '24 05:03 zhengzangw