Diffusers LoRA inference code does not work.
Hi,
I LoRA fine-tuned T2V according to https://github.com/THUDM/CogVideo/tree/main/finetune and I am trying to run inference on the fine-tuned model with the following command:
python cli_demo.py --prompt "a girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v" --lora_path MyPath/checkpoint-10000
However, the output looks as if no LoRA fine-tuning had been applied.
- The validation output video (during training) does show that the output changes according to the fine-tuning dataset.
- But running inference with python cli_demo.py results in videos that look like they come from the pretrained, untuned model.
On line 85, try changing pipe.fuse_lora(lora_scale=1 / lora_rank) to pipe.fuse_lora(lora_scale=1, components=['transformer']); that is how I got it to work.
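For context, a minimal sketch of what that suggested call looks like in a standalone Diffusers script (this assumes a recent diffusers release in which fuse_lora accepts a components argument; the model name and checkpoint path are the ones from the original post):

import torch
from diffusers import CogVideoXPipeline

# Load the base model and the LoRA checkpoint the same way cli_demo.py does.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    "MyPath/checkpoint-10000",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="test_1",
)
# The change suggested above for line 85: fuse only the transformer LoRA, at full scale.
pipe.fuse_lora(lora_scale=1.0, components=["transformer"])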
When LoRA is injected into video generation like this, what does it control? How good are the results?
It controls the character's appearance, and the 5B model works quite well. I could not train a style LoRA, though: that would need around 6,000 videos and it blew past 2 TB of memory. A character LoRA trains fine with about 200 videos.
https://github.com/user-attachments/assets/86d46389-c123-4e64-8976-cff1d4e2ceca
Below is the reference character:
https://github.com/user-attachments/assets/a889f39b-1711-4c23-bd1d-0a40c39afbcf
Impressive demo. Is this the same principle as an image LoRA? You train a character video LoRA on about 200 4-second videos of the person, each with a caption?
When I use the LoRA weights for inference with python cli_demo.py, the following error occurs:
ValueError: Target modules {'base_model.model.transformer_blocks.22.attn1.to_v', 'base_model.model.transformer_blocks.0.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_v', 'base_model.model.transformer_blocks.26.attn1.to_q', 'base_model.model.transformer_blocks.10.attn1.to_q', 'base_model.model.transformer_blocks.18.attn1.to_v', 'base_model.model.transformer_blocks.19.attn1.to_q', 'base_model.model.transformer_blocks.18.attn1.to_q', 'base_model.model.transformer_blocks.25.attn1.to_q', 'base_model.model.transformer_blocks.19.attn1.to_v', 'base_model.model.transformer_blocks.6.attn1.to_v', 'base_model.model.transformer_blocks.22.attn1.to_out.0', 'base_model.model.transformer_blocks.0.attn1.to_k', 'base_model.model.transformer_blocks.14.attn1.to_k', 'base_model.model.transformer_blocks.29.attn1.to_q', 'base_model.model.transformer_blocks.27.attn1.to_k', 'base_model.model.transformer_blocks.11.attn1.to_q', 'base_model.model.transformer_blocks.2.attn1.to_out.0', 'base_model.model.transformer_blocks.10.attn1.to_out.0', 'base_model.model.transformer_blocks.26.attn1.to_out.0', 'base_model.model.transformer_blocks.5.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_q', 'base_model.model.transformer_blocks.6.attn1.to_q', 'base_model.model.transformer_blocks.26.attn1.to_v', 'base_model.model.transformer_blocks.15.attn1.to_out.0', 'base_model.model.transformer_blocks.25.attn1.to_v', 'base_model.model.transformer_blocks.24.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_k', 'base_model.model.transformer_blocks.23.attn1.to_k', 'base_model.model.transformer_blocks.9.attn1.to_out.0', 'base_model.model.transformer_blocks.3.attn1.to_q', 'base_model.model.transformer_blocks.21.attn1.to_v', 'base_model.model.transformer_blocks.2.attn1.to_k', 'base_model.model.transformer_blocks.12.attn1.to_out.0', 'base_model.model.transformer_blocks.4.attn1.to_v', 'base_model.model.transformer_blocks.28.attn1.to_v', 'base_model.model.transformer_blocks.27.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_k', 'base_model.model.transformer_blocks.13.attn1.to_v', 'base_model.model.transformer_blocks.27.attn1.to_out.0', 'base_model.model.transformer_blocks.12.attn1.to_v', 'base_model.model.transformer_blocks.21.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_out.0', 'base_model.model.transformer_blocks.13.attn1.to_q', 'base_model.model.transformer_blocks.22.attn1.to_q', 'base_model.model.transformer_blocks.0.attn1.to_q', 'base_model.model.transformer_blocks.8.attn1.to_v', 'base_model.model.transformer_blocks.11.attn1.to_k', 'base_model.model.transformer_blocks.26.attn1.to_k', 'base_model.model.transformer_blocks.28.attn1.to_q', 'base_model.model.transformer_blocks.22.attn1.to_k', 'base_model.model.transformer_blocks.11.attn1.to_v', 'base_model.model.transformer_blocks.14.attn1.to_v', 'base_model.model.transformer_blocks.16.attn1.to_k', 'base_model.model.transformer_blocks.24.attn1.to_k', 'base_model.model.transformer_blocks.28.attn1.to_k', 'base_model.model.transformer_blocks.10.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_k', 'base_model.model.transformer_blocks.15.attn1.to_q', 'base_model.model.transformer_blocks.16.attn1.to_out.0', 'base_model.model.transformer_blocks.2.attn1.to_q', 'base_model.model.transformer_blocks.5.attn1.to_q', 'base_model.model.transformer_blocks.19.attn1.to_out.0', 'base_model.model.transformer_blocks.27.attn1.to_v', 'base_model.model.transformer_blocks.7.attn1.to_k', 'base_model.model.transformer_blocks.7.attn1.to_out.0', 
'base_model.model.transformer_blocks.2.attn1.to_v', 'base_model.model.transformer_blocks.6.attn1.to_k', 'base_model.model.transformer_blocks.21.attn1.to_k', 'base_model.model.transformer_blocks.15.attn1.to_k', 'base_model.model.transformer_blocks.13.attn1.to_k', 'base_model.model.transformer_blocks.18.attn1.to_k', 'base_model.model.transformer_blocks.21.attn1.to_out.0', 'base_model.model.transformer_blocks.23.attn1.to_q', 'base_model.model.transformer_blocks.23.attn1.to_v', 'base_model.model.transformer_blocks.20.attn1.to_v', 'base_model.model.transformer_blocks.4.attn1.to_q', 'base_model.model.transformer_blocks.3.attn1.to_k', 'base_model.model.transformer_blocks.20.attn1.to_q', 'base_model.model.transformer_blocks.17.attn1.to_q', 'base_model.model.transformer_blocks.25.attn1.to_out.0', 'base_model.model.transformer_blocks.23.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_v', 'base_model.model.transformer_blocks.1.attn1.to_v', 'base_model.model.transformer_blocks.20.attn1.to_k', 'base_model.model.transformer_blocks.14.attn1.to_q', 'base_model.model.transformer_blocks.1.attn1.to_out.0', 'base_model.model.transformer_blocks.20.attn1.to_out.0', 'base_model.model.transformer_blocks.7.attn1.to_v', 'base_model.model.transformer_blocks.5.attn1.to_k', 'base_model.model.transformer_blocks.16.attn1.to_q', 'base_model.model.transformer_blocks.1.attn1.to_q', 'base_model.model.transformer_blocks.4.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_v', 'base_model.model.transformer_blocks.1.attn1.to_k', 'base_model.model.transformer_blocks.3.attn1.to_out.0', 'base_model.model.transformer_blocks.7.attn1.to_q', 'base_model.model.transformer_blocks.12.attn1.to_k', 'base_model.model.transformer_blocks.10.attn1.to_v', 'base_model.model.transformer_blocks.3.attn1.to_v', 'base_model.model.transformer_blocks.24.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_out.0', 'base_model.model.transformer_blocks.16.attn1.to_v', 'base_model.model.transformer_blocks.13.attn1.to_out.0', 'base_model.model.transformer_blocks.5.attn1.to_out.0', 'base_model.model.transformer_blocks.6.attn1.to_out.0', 'base_model.model.transformer_blocks.14.attn1.to_out.0', 'base_model.model.transformer_blocks.11.attn1.to_out.0', 'base_model.model.transformer_blocks.24.attn1.to_q', 'base_model.model.transformer_blocks.28.attn1.to_out.0', 'base_model.model.transformer_blocks.18.attn1.to_out.0', 'base_model.model.transformer_blocks.12.attn1.to_q', 'base_model.model.transformer_blocks.15.attn1.to_v', 'base_model.model.transformer_blocks.19.attn1.to_k', 'base_model.model.transformer_blocks.25.attn1.to_k', 'base_model.model.transformer_blocks.0.attn1.to_out.0', 'base_model.model.transformer_blocks.4.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_out.0'} not found in the base model. Please check the target modules and try again.
I think the LoRA weights are not being merged into the model, but cli_demo.py does have code to load the LoRA:
if lora_path:
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
    pipe.fuse_lora(lora_scale=1 / lora_rank)
The command is python inference/cli_demo.py --prompt "a man is walking on the street." --lora_path my-path/checkpoint-10000
Please tell me how to solve it. Thanks!
I met the same issue. I think the lora_scale should be (lora_alpha / lora_rank), which is the scaling used during LoRA fine-tuning? However, after making that change, the test results still differ from the results shown during fine-tuning.
Any clues on this? Thanks!
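As a minimal sketch of that calculation (the rank and alpha values below are the ones mentioned later in this thread, not defaults; substitute the values from your own training config):

# PEFT scales the LoRA update by alpha / rank during training, so the same
# ratio is a natural choice when fusing at inference time.
lora_rank = 128   # assumed: the rank used for fine-tuning
lora_alpha = 64   # assumed: the alpha used for fine-tuning
lora_scale = lora_alpha / lora_rank   # 0.5 here, rather than 1 / lora_rank
# With a pipeline loaded as in cli_demo.py:
# pipe.fuse_lora(lora_scale=lora_scale)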
The SAT framework is somewhat different from the Diffusers framework, and weights trained with SAT have to be run through their own inference path.
If you need Diffusers inference, you first have to convert the weights and then load them with the Diffusers script:
- convert
python export_sat_lora_weight.py --sat_pt_path <model_path>/lora-disney-09-21-21-51/1000/mp_rank_00_model_states.pt --lora_save_directory <model_path>/export_hf_lora_base_009_1000_weights
- infer
python load_cogvideox_lora.py --pretrained_model_name_or_path <model_path>/CogVideoX-2b-base --lora_weights_path <model_path>/export_hf_lora_base_009_1000_weights --lora_r 256 --prompt " "
Thanks for the reply @glide-the. I am using the Diffusers fine-tuning directly, not SAT, and I am using cli_demo.py for inference. But the evaluation results are worse than the results shown during fine-tuning (I mean the validation in wandb). Could you take a look at it?
BTW, I have revised the lora_scale calculation. For example, I used rank=128 and alpha=64, and I changed the "pipe.fuse_lora(lora_scale=1 / lora_rank)" line accordingly, replacing "1" with "64" and using "lora_rank=128", i.e. lora_scale = 64 / 128.
Thanks.
The code of cli_demo.py has not been updated yet. Clone and install the diffusers repository, and use this PR to solve the problem.
Reference PR: https://github.com/THUDM/CogVideo/pull/411
Hi @glide-the and @zRzRzRzRzRzRzR, sorry for the late response.
I've tried "load_cogvideox_lora.py" to test the model with the LoRA, and it can run inference. But I still have a question: the results generated (1) during LoRA training and (2) after LoRA training with "load_cogvideox_lora.py" are different.
For example, when training finishes, the code generates a final test video and a LoRA checkpoint. If I then use this checkpoint for inference with "load_cogvideox_lora.py", the generated video differs from the former one.
I use the same prompt and the same random seed, which I think should remove any randomness and produce two exactly identical videos.
Is there any clue on this?
Thanks!
For more information, I also checked the guidance scale, use_dynamic_cfg, dtype, etc., and they all seem correct.
Thanks!
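For anyone trying to reproduce that comparison, here is a minimal sketch of a fully seeded Diffusers inference call (the seed, step count, and guidance values are illustrative; any difference in scheduler, resolution, frame count, or dtype between the training-time validation and cli_demo.py would still change the output):

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("MyPath/checkpoint-10000", weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora(lora_scale=64 / 128, components=["transformer"])  # alpha / rank, as discussed above

# Pin the noise sampling with an explicit generator so repeated runs match.
generator = torch.Generator(device="cuda").manual_seed(42)
video = pipe(
    prompt="a man is walking on the street.",
    num_inference_steps=50,
    guidance_scale=6.0,
    use_dynamic_cfg=True,
    generator=generator,
).frames[0]
export_to_video(video, "output.mp4", fps=8)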
@justlovebarbecue Hi, I met the same problem. Can you tell me how you solved it? Thanks!