Wonder3D

Stage 2 training BUG: following keys are missing

Open RenieWell opened this issue 1 year ago • 8 comments

Thanks for sharing this work with us!

I have trained the stage 1 model and got the expected performance, initializing the UNet weights from the pretrained model (pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers').

But after training, the stage 1 model can't be loaded by the stage 2 code, and I get the following error:

load pre-trained unet from ./outputs/wonder3D-mix-vanila/checkpoint/
Traceback (most recent call last):
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 773, in <module>
    main(cfg)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 251, in main
    unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision,
**cfg.unet_from_pretrained_kwargs) File "/data/.conda/envs/marigold/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 604, in from_pretrained raise ValueError( ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from ./outputs/wonder3D-mix-vanila/checkpoint/ because the following keys are missing: up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, 
up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, 
up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight. Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

As I understand it, the missing weights are supposed to be trained in stage 2, so they can't be loaded before stage 2 training starts. Do you have any ideas about this?
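
One way to check that reading of the error is to group the missing keys by layer name. The following is a minimal sketch, not part of the Wonder3D code: the helper names parse_missing_keys and count_by_layer are made up for illustration, and SAMPLE_ERROR is the message above shortened to four keys.

```python
import re
from collections import Counter

# Shortened excerpt of the error message above (four keys only, for illustration).
SAMPLE_ERROR = (
    "ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> "
    "from ./outputs/wonder3D-mix-vanila/checkpoint/ because the following keys are missing: "
    "up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, "
    "up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, "
    "mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, "
    "mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias. "
    "Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to "
    "randomly initialize those weights or else make sure your checkpoint file is correct."
)


def parse_missing_keys(error_message):
    """Extract the parameter names listed in a 'keys are missing' error message."""
    # Drop distributed-launcher prefixes such as "[rank0]:" if they are present.
    cleaned = re.sub(r"\[rank\d+\]:\s*", "", error_message)
    match = re.search(
        r"following keys are missing:\s*(.*?)\.\s*Please make sure", cleaned, re.S
    )
    if match is None:
        return []
    return [key.strip() for key in match.group(1).split(",") if key.strip()]


def count_by_layer(keys):
    """Count keys per layer name that follows 'transformer_blocks.<index>.'."""
    counts = Counter()
    for key in keys:
        m = re.search(r"transformer_blocks\.\d+\.([A-Za-z_]+)", key)
        counts[m.group(1) if m else "other"] += 1
    return counts


if __name__ == "__main__":
    keys = parse_missing_keys(SAMPLE_ERROR)
    for layer, n in sorted(count_by_layer(keys).items()):
        print(layer, n)
```

If every key falls under attn_joint_mid or norm_joint_mid, as it does in the excerpt, the stage 1 checkpoint is only missing the joint-attention layers and nothing else.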

RenieWell avatar May 31 '24 02:05 RenieWell

similar problem!

SunzeY avatar Jun 08 '24 20:06 SunzeY

same problem. @RenieWell @SunzeY @xxlong0 Have you solved this?

bbbbubble avatar Jun 24 '24 05:06 bbbbubble

If I add low_cpu_mem_usage=False to the from_pretrained() call, or use the from_pretrained_2d() function instead, training starts successfully, but it is hard to converge...

image
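
The error message itself says that passing low_cpu_mem_usage=False and device_map=None makes the loader randomly initialize the missing weights. The sketch below is a toy model of that behaviour in plain Python, not the diffusers implementation: load_non_strict is a hypothetical helper, short lists stand in for tensors, and the key names and values are only illustrative.

```python
def load_non_strict(fresh_state, checkpoint_state):
    """Merge a checkpoint into freshly initialized parameters.

    Returns (merged, missing_keys, unexpected_keys):
    - merged: checkpoint values where available, fresh values otherwise
    - missing_keys: parameters the model has but the checkpoint lacks
    - unexpected_keys: checkpoint entries the model has no parameter for
    """
    merged = dict(fresh_state)
    missing_keys = []
    for key in fresh_state:
        if key in checkpoint_state:
            merged[key] = checkpoint_state[key]
        else:
            missing_keys.append(key)  # stays at its fresh (random) initialization
    unexpected_keys = [k for k in checkpoint_state if k not in fresh_state]
    return merged, missing_keys, unexpected_keys


# Stand-in for a stage 2 model: one layer shared with stage 1, one joint layer.
fresh = {
    "conv_in.weight": [0.01, 0.02],
    "mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight": [0.03, 0.04],
}
# Stand-in for a stage 1 checkpoint trained without the joint layers.
stage1_ckpt = {"conv_in.weight": [0.5, -0.5]}

if __name__ == "__main__":
    merged, missing, unexpected = load_non_strict(fresh, stage1_ckpt)
    print("newly initialized:", missing)
    print("ignored checkpoint entries:", unexpected)
```

In PyTorch, model.load_state_dict(state_dict, strict=False) reports missing_keys and unexpected_keys in the same spirit. Listing them makes explicit which layers start from scratch in stage 2.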

bbbbubble avatar Jun 27 '24 11:06 bbbbubble

Hi, I am running into the same problem. Have you solved this? @bbbbubble @SunzeY @RenieWell

yyuezhi avatar Aug 22 '24 22:08 yyuezhi

Maybe try replacing "from_pretrained" with "from_pretrained_2d" in https://github.com/xxlong0/Wonder3D/blob/deeba9833570fce09dd4da393f6318475e85a735/train_mvdiffusion_joint.py#L251 ?

mengxuyiGit avatar Aug 28 '24 14:08 mengxuyiGit

@bbbbubble @mengxuyiGit @RenieWell @SunzeY

Hello, if the problem persists, maybe you can try my solution: change the cd_attention_mid attribute in "./configs/train/stage1-mix-6views-lvis.yaml" from false to true,

and change line 237 in ./train_mvdiffusion_image.py to

unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)

It works for me. It seems the problem is not in the second stage but in the first. Please let me know if anything comes up.

liuyifan22 avatar Oct 14 '24 13:10 liuyifan22

This doesn't seem to work. A "missing keys" error now shows up in the first stage:

[rank0]: ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from /home/azhe.cp/avatar_utils/sd-image-variations-diffusers because the following keys are missing:
[rank0]: mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias.
[rank0]: Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

bbbbubble avatar Oct 15 '24 11:10 bbbbubble

Sorry, I was using Wonder3D's checkpoint as the input for my stage 1 training, and my change works with that checkpoint. If you want to train from lambdalabs/sd-image-variations-diffusers, my method will not work. Sorry for the confusion.

liuyifan22 avatar Oct 15 '24 12:10 liuyifan22