Error when running inference
Running the inference sample:
python sample.py -m "DiT/XL-2" --text "a person is walking on the street" --ckpt /path/to/checkpoint --height 256 --width 256 --fps 10 --sec 5 --disable-cfg
I got the following error:
usage: sample.py [-h] [-m {DiT-XL/2,DiT-XL/4,DiT-XL/8,DiT-L/2,DiT-L/4,DiT-L/8,DiT-B/2,DiT-B/4,DiT-B/8,DiT-S/2,DiT-S/4,DiT-S/8}] [--text TEXT] [--cfg-scale CFG_SCALE] [--num-sampling-steps NUM_SAMPLING_STEPS] [--seed SEED] --ckpt CKPT [-c {raw,vqvae,vae}] [--text_model TEXT_MODEL] [--width WIDTH] [--height HEIGHT] [--fps FPS] [--sec SEC] [--disable-cfg] sample.py: error: argument -m/--model: invalid choice: 'DiT/XL-2' (choose from 'DiT-XL/2', 'DiT-XL/4', 'DiT-XL/8', 'DiT-L/2', 'DiT-L/4', 'DiT-L/8', 'DiT-B/2', 'DiT-B/4', 'DiT-B/8', 'DiT-S/2', 'DiT-S/4', 'DiT-S/8')
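The first error comes from argparse's choices check: the model name must match one of the listed strings exactly, in the form DiT-<size>/<patch>, and "DiT/XL-2" has the hyphen and slash swapped. The sketch below is a hypothetical reconstruction of that option, built only from the usage message (the real sample.py may define it differently). It shows "DiT-XL/2" being accepted and "DiT/XL-2" being rejected.

```python
import argparse

# Hypothetical reconstruction of sample.py's -m/--model option, inferred only
# from the usage message: names have the form "DiT-<size>/<patch>".
SIZES = ["XL", "L", "B", "S"]
PATCH_SIZES = [2, 4, 8]
MODEL_CHOICES = [f"DiT-{size}/{patch}" for size in SIZES for patch in PATCH_SIZES]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="sample.py")
    # argparse compares the raw string against this list; no normalization happens.
    parser.add_argument("-m", "--model", choices=MODEL_CHOICES)
    return parser


if __name__ == "__main__":
    parser = build_parser()
    print(parser.parse_args(["-m", "DiT-XL/2"]).model)  # accepted
    try:
        parser.parse_args(["-m", "DiT/XL-2"])  # hyphen and slash swapped
    except SystemExit as exc:
        # argparse prints "invalid choice" to stderr and exits with status 2
        print(f"rejected with exit status {exc.code}")
```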
Then I changed the command to:
!python sample.py -m "DiT-XL/2" --text "a person is walking on the street" --ckpt pretrained_models/DiT-XL-2-256x256.pt --height 256 --width 256 --fps 10 --sec 5 --disable-cfg
But I got a different error:
Traceback (most recent call last):
  File "/content/Open-Sora/sample.py", line 136, in <module>
    main(args)
  File "/content/Open-Sora/sample.py", line 39, in main
    model.load_state_dict(torch.load(args.ckpt))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DiT:
Missing key(s) in state_dict: "video_embedder.proj.weight", "video_embedder.proj.bias", "blocks.0.attn.to_q.weight", "blocks.0.attn.to_q.bias", "blocks.0.attn.to_k.weight", "blocks.0.attn.to_k.bias", "blocks.0.attn.to_v.weight", "blocks.0.attn.to_v.bias", "blocks.0.attn.to_out.0.weight", "blocks.0.attn.to_out.0.bias", "blocks.1.attn.to_q.weight", "blocks.1.attn.to_q.bias", "blocks.1.attn.to_k.weight", "blocks.1.attn.to_k.bias", "blocks.1.attn.to_v.weight", "blocks.1.attn.to_v.bias", "blocks.1.attn.to_out.0.weight", "blocks.1.attn.to_out.0.bias", "blocks.2.attn.to_q.weight", "blocks.2.attn.to_q.bias", "blocks.2.attn.to_k.weight", "blocks.2.attn.to_k.bias", "blocks.2.attn.to_v.weight", "blocks.2.attn.to_v.bias", "blocks.2.attn.to_out.0.weight", "blocks.2.attn.to_out.0.bias", "blocks.3.attn.to_q.weight", "blocks.3.attn.to_q.bias", "blocks.3.attn.to_k.weight", "blocks.3.attn.to_k.bias", "blocks.3.attn.to_v.weight", "blocks.3.attn.to_v.bias", "blocks.3.attn.to_out.0.weight", "blocks.3.attn.to_out.0.bias", "blocks.4.attn.to_q.weight", "blocks.4.attn.to_q.bias", "blocks.4.attn.to_k.weight", "blocks.4.attn.to_k.bias", "blocks.4.attn.to_v.weight", "blocks.4.attn.to_v.bias", "blocks.4.attn.to_out.0.weight", "blocks.4.attn.to_out.0.bias", "blocks.5.attn.to_q.weight", "blocks.5.attn.to_q.bias", "blocks.5.attn.to_k.weight", "blocks.5.attn.to_k.bias", "blocks.5.attn.to_v.weight", "blocks.5.attn.to_v.bias", "blocks.5.attn.to_out.0.weight", "blocks.5.attn.to_out.0.bias", "blocks.6.attn.to_q.weight", "blocks.6.attn.to_q.bias", "blocks.6.attn.to_k.weight", "blocks.6.attn.to_k.bias", "blocks.6.attn.to_v.weight", "blocks.6.attn.to_v.bias", 
"blocks.6.attn.to_out.0.weight", "blocks.6.attn.to_out.0.bias", "blocks.7.attn.to_q.weight", "blocks.7.attn.to_q.bias", "blocks.7.attn.to_k.weight", "blocks.7.attn.to_k.bias", "blocks.7.attn.to_v.weight", "blocks.7.attn.to_v.bias", "blocks.7.attn.to_out.0.weight", "blocks.7.attn.to_out.0.bias", "blocks.8.attn.to_q.weight", "blocks.8.attn.to_q.bias", "blocks.8.attn.to_k.weight", "blocks.8.attn.to_k.bias", "blocks.8.attn.to_v.weight", "blocks.8.attn.to_v.bias", "blocks.8.attn.to_out.0.weight", "blocks.8.attn.to_out.0.bias", "blocks.9.attn.to_q.weight", "blocks.9.attn.to_q.bias", "blocks.9.attn.to_k.weight", "blocks.9.attn.to_k.bias", "blocks.9.attn.to_v.weight", "blocks.9.attn.to_v.bias", "blocks.9.attn.to_out.0.weight", "blocks.9.attn.to_out.0.bias", "blocks.10.attn.to_q.weight", "blocks.10.attn.to_q.bias", "blocks.10.attn.to_k.weight", "blocks.10.attn.to_k.bias", "blocks.10.attn.to_v.weight", "blocks.10.attn.to_v.bias", "blocks.10.attn.to_out.0.weight", "blocks.10.attn.to_out.0.bias", "blocks.11.attn.to_q.weight", "blocks.11.attn.to_q.bias", "blocks.11.attn.to_k.weight", "blocks.11.attn.to_k.bias", "blocks.11.attn.to_v.weight", "blocks.11.attn.to_v.bias", "blocks.11.attn.to_out.0.weight", "blocks.11.attn.to_out.0.bias", "blocks.12.attn.to_q.weight", "blocks.12.attn.to_q.bias", "blocks.12.attn.to_k.weight", "blocks.12.attn.to_k.bias", "blocks.12.attn.to_v.weight", "blocks.12.attn.to_v.bias", "blocks.12.attn.to_out.0.weight", "blocks.12.attn.to_out.0.bias", "blocks.13.attn.to_q.weight", "blocks.13.attn.to_q.bias", "blocks.13.attn.to_k.weight", "blocks.13.attn.to_k.bias", "blocks.13.attn.to_v.weight", "blocks.13.attn.to_v.bias", "blocks.13.attn.to_out.0.weight", "blocks.13.attn.to_out.0.bias", "blocks.14.attn.to_q.weight", "blocks.14.attn.to_q.bias", "blocks.14.attn.to_k.weight", "blocks.14.attn.to_k.bias", "blocks.14.attn.to_v.weight", "blocks.14.attn.to_v.bias", "blocks.14.attn.to_out.0.weight", "blocks.14.attn.to_out.0.bias", "blocks.15.attn.to_q.weight", 
"blocks.15.attn.to_q.bias", "blocks.15.attn.to_k.weight", "blocks.15.attn.to_k.bias", "blocks.15.attn.to_v.weight", "blocks.15.attn.to_v.bias", "blocks.15.attn.to_out.0.weight", "blocks.15.attn.to_out.0.bias", "blocks.16.attn.to_q.weight", "blocks.16.attn.to_q.bias", "blocks.16.attn.to_k.weight", "blocks.16.attn.to_k.bias", "blocks.16.attn.to_v.weight", "blocks.16.attn.to_v.bias", "blocks.16.attn.to_out.0.weight", "blocks.16.attn.to_out.0.bias", "blocks.17.attn.to_q.weight", "blocks.17.attn.to_q.bias", "blocks.17.attn.to_k.weight", "blocks.17.attn.to_k.bias", "blocks.17.attn.to_v.weight", "blocks.17.attn.to_v.bias", "blocks.17.attn.to_out.0.weight", "blocks.17.attn.to_out.0.bias", "blocks.18.attn.to_q.weight", "blocks.18.attn.to_q.bias", "blocks.18.attn.to_k.weight", "blocks.18.attn.to_k.bias", "blocks.18.attn.to_v.weight", "blocks.18.attn.to_v.bias", "blocks.18.attn.to_out.0.weight", "blocks.18.attn.to_out.0.bias", "blocks.19.attn.to_q.weight", "blocks.19.attn.to_q.bias", "blocks.19.attn.to_k.weight", "blocks.19.attn.to_k.bias", "blocks.19.attn.to_v.weight", "blocks.19.attn.to_v.bias", "blocks.19.attn.to_out.0.weight", "blocks.19.attn.to_out.0.bias", "blocks.20.attn.to_q.weight", "blocks.20.attn.to_q.bias", "blocks.20.attn.to_k.weight", "blocks.20.attn.to_k.bias", "blocks.20.attn.to_v.weight", "blocks.20.attn.to_v.bias", "blocks.20.attn.to_out.0.weight", "blocks.20.attn.to_out.0.bias", "blocks.21.attn.to_q.weight", "blocks.21.attn.to_q.bias", "blocks.21.attn.to_k.weight", "blocks.21.attn.to_k.bias", "blocks.21.attn.to_v.weight", "blocks.21.attn.to_v.bias", "blocks.21.attn.to_out.0.weight", "blocks.21.attn.to_out.0.bias", "blocks.22.attn.to_q.weight", "blocks.22.attn.to_q.bias", "blocks.22.attn.to_k.weight", "blocks.22.attn.to_k.bias", "blocks.22.attn.to_v.weight", "blocks.22.attn.to_v.bias", "blocks.22.attn.to_out.0.weight", "blocks.22.attn.to_out.0.bias", "blocks.23.attn.to_q.weight", "blocks.23.attn.to_q.bias", "blocks.23.attn.to_k.weight", 
"blocks.23.attn.to_k.bias", "blocks.23.attn.to_v.weight", "blocks.23.attn.to_v.bias", "blocks.23.attn.to_out.0.weight", "blocks.23.attn.to_out.0.bias", "blocks.24.attn.to_q.weight", "blocks.24.attn.to_q.bias", "blocks.24.attn.to_k.weight", "blocks.24.attn.to_k.bias", "blocks.24.attn.to_v.weight", "blocks.24.attn.to_v.bias", "blocks.24.attn.to_out.0.weight", "blocks.24.attn.to_out.0.bias", "blocks.25.attn.to_q.weight", "blocks.25.attn.to_q.bias", "blocks.25.attn.to_k.weight", "blocks.25.attn.to_k.bias", "blocks.25.attn.to_v.weight", "blocks.25.attn.to_v.bias", "blocks.25.attn.to_out.0.weight", "blocks.25.attn.to_out.0.bias", "blocks.26.attn.to_q.weight", "blocks.26.attn.to_q.bias", "blocks.26.attn.to_k.weight", "blocks.26.attn.to_k.bias", "blocks.26.attn.to_v.weight", "blocks.26.attn.to_v.bias", "blocks.26.attn.to_out.0.weight", "blocks.26.attn.to_out.0.bias", "blocks.27.attn.to_q.weight", "blocks.27.attn.to_q.bias", "blocks.27.attn.to_k.weight", "blocks.27.attn.to_k.bias", "blocks.27.attn.to_v.weight", "blocks.27.attn.to_v.bias", "blocks.27.attn.to_out.0.weight", "blocks.27.attn.to_out.0.bias". 
Unexpected key(s) in state_dict: "y_embedder.embedding_table.weight", "x_embedder.proj.weight", "x_embedder.proj.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "blocks.5.attn.qkv.weight", "blocks.5.attn.qkv.bias", "blocks.5.attn.proj.weight", "blocks.5.attn.proj.bias", "blocks.6.attn.qkv.weight", "blocks.6.attn.qkv.bias", "blocks.6.attn.proj.weight", "blocks.6.attn.proj.bias", "blocks.7.attn.qkv.weight", "blocks.7.attn.qkv.bias", "blocks.7.attn.proj.weight", "blocks.7.attn.proj.bias", "blocks.8.attn.qkv.weight", "blocks.8.attn.qkv.bias", "blocks.8.attn.proj.weight", "blocks.8.attn.proj.bias", "blocks.9.attn.qkv.weight", "blocks.9.attn.qkv.bias", "blocks.9.attn.proj.weight", "blocks.9.attn.proj.bias", "blocks.10.attn.qkv.weight", "blocks.10.attn.qkv.bias", "blocks.10.attn.proj.weight", "blocks.10.attn.proj.bias", "blocks.11.attn.qkv.weight", "blocks.11.attn.qkv.bias", "blocks.11.attn.proj.weight", "blocks.11.attn.proj.bias", "blocks.12.attn.qkv.weight", "blocks.12.attn.qkv.bias", "blocks.12.attn.proj.weight", "blocks.12.attn.proj.bias", "blocks.13.attn.qkv.weight", "blocks.13.attn.qkv.bias", "blocks.13.attn.proj.weight", "blocks.13.attn.proj.bias", "blocks.14.attn.qkv.weight", "blocks.14.attn.qkv.bias", "blocks.14.attn.proj.weight", "blocks.14.attn.proj.bias", "blocks.15.attn.qkv.weight", "blocks.15.attn.qkv.bias", "blocks.15.attn.proj.weight", "blocks.15.attn.proj.bias", "blocks.16.attn.qkv.weight", "blocks.16.attn.qkv.bias", "blocks.16.attn.proj.weight", 
"blocks.16.attn.proj.bias", "blocks.17.attn.qkv.weight", "blocks.17.attn.qkv.bias", "blocks.17.attn.proj.weight", "blocks.17.attn.proj.bias", "blocks.18.attn.qkv.weight", "blocks.18.attn.qkv.bias", "blocks.18.attn.proj.weight", "blocks.18.attn.proj.bias", "blocks.19.attn.qkv.weight", "blocks.19.attn.qkv.bias", "blocks.19.attn.proj.weight", "blocks.19.attn.proj.bias", "blocks.20.attn.qkv.weight", "blocks.20.attn.qkv.bias", "blocks.20.attn.proj.weight", "blocks.20.attn.proj.bias", "blocks.21.attn.qkv.weight", "blocks.21.attn.qkv.bias", "blocks.21.attn.proj.weight", "blocks.21.attn.proj.bias", "blocks.22.attn.qkv.weight", "blocks.22.attn.qkv.bias", "blocks.22.attn.proj.weight", "blocks.22.attn.proj.bias", "blocks.23.attn.qkv.weight", "blocks.23.attn.qkv.bias", "blocks.23.attn.proj.weight", "blocks.23.attn.proj.bias", "blocks.24.attn.qkv.weight", "blocks.24.attn.qkv.bias", "blocks.24.attn.proj.weight", "blocks.24.attn.proj.bias", "blocks.25.attn.qkv.weight", "blocks.25.attn.qkv.bias", "blocks.25.attn.proj.weight", "blocks.25.attn.proj.bias", "blocks.26.attn.qkv.weight", "blocks.26.attn.qkv.bias", "blocks.26.attn.proj.weight", "blocks.26.attn.proj.bias", "blocks.27.attn.qkv.weight", "blocks.27.attn.qkv.bias", "blocks.27.attn.proj.weight", "blocks.27.attn.proj.bias". size mismatch for final_layer.linear.weight: copying a param with shape torch.Size([32, 1152]) from checkpoint, the shape in current model is torch.Size([24, 1152]). size mismatch for final_layer.linear.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([24]).
I would appreciate any help solving it.
Hi, are you using the original DiT pretrained weights? We modified the model architecture, so the original weights cannot be loaded. You can get compatible weights by saving a checkpoint with our training script.
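One way to spot this kind of architecture mismatch before calling load_state_dict is to diff the two state dicts by key and shape. The sketch below uses a hypothetical helper, diff_state_dicts (not part of the repo or of PyTorch). The toy OriginalLayout and ExpectedLayout modules only imitate the two key layouts visible in the traceback: a fused qkv/proj attention versus separate to_q/to_k/to_v/to_out.0 projections, plus a final layer with 32 versus 24 outputs.

```python
from torch import nn


def diff_state_dicts(model_sd, ckpt_sd):
    """Hypothetical helper: list missing, unexpected and shape-mismatched keys,
    roughly as load_state_dict reports them."""
    missing = sorted(k for k in model_sd if k not in ckpt_sd)
    unexpected = sorted(k for k in ckpt_sd if k not in model_sd)
    mismatched = sorted(
        (k, tuple(ckpt_sd[k].shape), tuple(model_sd[k].shape))
        for k in model_sd.keys() & ckpt_sd.keys()
        if ckpt_sd[k].shape != model_sd[k].shape
    )
    return missing, unexpected, mismatched


class OriginalLayout(nn.Module):
    """Toy stand-in for the original DiT key layout (fused qkv + proj)."""

    def __init__(self, dim: int = 8, out: int = 32):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.final = nn.Linear(dim, out)


class ExpectedLayout(nn.Module):
    """Toy stand-in for the layout the modified model expects."""

    def __init__(self, dim: int = 8, out: int = 24):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Sequential(nn.Linear(dim, dim))  # keys: to_out.0.*
        self.final = nn.Linear(dim, out)


if __name__ == "__main__":
    missing, unexpected, mismatched = diff_state_dicts(
        ExpectedLayout().state_dict(), OriginalLayout().state_dict()
    )
    print("missing:", missing)
    print("unexpected:", unexpected)
    print("size mismatch:", mismatched)
```

Against the real files, something like diff_state_dicts(model.state_dict(), torch.load(args.ckpt, map_location="cpu")) should list the same keys as the traceback. Passing strict=False to load_state_dict would not help here: it only ignores missing and unexpected keys (leaving those layers at their initial values), and PyTorch still raises on the final_layer size mismatch.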
There are key names with the video_embedder prefix. It seems they were trained with the network in this repo.
Maybe you trained DiT-S/8 by default?
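To check which variant a checkpoint was actually trained as (for example, whether training defaulted to DiT-S/8), you can count the transformer blocks, read the hidden size from the final layer, and compare against the standard DiT sizes (S: 12 blocks/384, B: 12/768, L: 24/1024, XL: 28/1152). describe_checkpoint below is a hypothetical helper, not part of the repo. It assumes the file holds a plain state dict, as the load_state_dict(torch.load(args.ckpt)) call in the traceback implies, and that key names follow the ones in the error message. It does not infer the patch size, because the final layer's output width also depends on the channel count.

```python
import re

import torch

# Standard DiT sizes as (depth, hidden size).
DIT_SIZES = {"S": (12, 384), "B": (12, 768), "L": (24, 1024), "XL": (28, 1152)}


def describe_checkpoint(state_dict):
    """Hypothetical helper: summarize a DiT-style state dict."""
    block_ids = set()
    for key in state_dict:
        match = re.match(r"blocks\.(\d+)\.", key)
        if match:
            block_ids.add(int(match.group(1)))
    depth = len(block_ids)
    # final_layer.linear maps hidden -> patch * patch * channels, so dim 1 is the hidden size
    hidden = state_dict["final_layer.linear.weight"].shape[1]
    size = next((name for name, dims in DIT_SIZES.items() if dims == (depth, hidden)), None)
    return {
        "depth": depth,
        "hidden": hidden,
        "size": size,  # None if it matches no standard DiT size
        "video_embedder": any(k.startswith("video_embedder.") for k in state_dict),
    }


if __name__ == "__main__":
    # Fake state dict shaped like a DiT-S checkpoint saved by this repo.
    fake = {
        "video_embedder.proj.weight": torch.empty(0),
        "final_layer.linear.weight": torch.empty(24, 384),
    }
    for i in range(12):
        fake[f"blocks.{i}.attn.to_q.weight"] = torch.empty(0)
    print(describe_checkpoint(fake))
```

On a real file this would be called as describe_checkpoint(torch.load("path/to/checkpoint.pt", map_location="cpu")). A result with size "S" would support the DiT-S/8 guess, while the original DiT-XL-2-256x256.pt, judging by the traceback (blocks 0 to 27, final layer [32, 1152], x_embedder keys), would report XL with video_embedder False.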
We have updated our inference code. Please see here for more instructions.