
Text encoder still not working correctly with LoRA Dreambooth training script

Open: JohnnyRacer opened this issue 2 years ago (28 comments)

Hello, I am getting much better results using the --train_text_encoder flag with the Dreambooth script. However, the LoRA .pt files output by models trained with train_text_encoder give very bad results when I use monkeypatch to generate images. I suspect that the text encoder's weights are still not being saved properly. I tried saving the pipeline directly after each epoch from within the training script, but loading it with diffusers gives me strange errors about torch not being able to parse the linear layers. Has anyone had similar experiences training the text encoder, or does anyone have an idea why this is happening?

Images sampled from within the training loop (train_text_encoder enabled): [images: 2, 2e, 3]

Images sampled after the model was monkeypatched with the trained LoRA weights (train_text_encoder enabled): [images: bad1, bad2, bad3]

The images don't seem to correlate with the samples generated during training and bear very little resemblance to the training images used.

JohnnyRacer, Dec 13 '22 21:12

Did you also patch the text encoder as well?

cloneofsimo, Dec 13 '22 22:12

Have a look at: https://github.com/cloneofsimo/lora/blob/master/scripts/run_with_text_lora_also.ipynb

cloneofsimo, Dec 13 '22 22:12

> Did you also patch the text encoder as well?

@cloneofsimo

It seems the repo was corrupted when I cloned it. I recloned it and retrained the model, and the text encoder files are there now.

JohnnyRacer, Dec 14 '22 01:12

@cloneofsimo I got the text encoder patching to work, but it seems to overfit very easily and sometimes gives quite bad results compared to training Dreambooth without LoRA. Even stranger, after applying the text encoder patch there are many more artifacts, such as watermarks from the training images, which had not appeared in the samples generated during training that I showed earlier. See the examples below. Any clue why this is occurring? [images: overfit, overfit1]

Update:

Using tune_lora_scale(pipe.text_encoder, 0.7) to lower the strength of the text encoder weights seems to help a little, but it also reduces overall image quality.

JohnnyRacer, Dec 14 '22 02:12
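For readers wondering what lowering the LoRA scale actually does, here is a toy illustration in plain Python. It assumes the standard LoRA formulation, in which a frozen weight W is adjusted by a low-rank product (up @ down) multiplied by a scale factor; it is not lora_diffusion's implementation of tune_lora_scale, and the matrices are made-up numbers.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product: rows of a times columns of b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def effective_weight(w: Matrix, lora_up: Matrix, lora_down: Matrix, scale: float) -> Matrix:
    """W + scale * (up @ down): the weight a LoRA-patched layer effectively applies."""
    delta = matmul(lora_up, lora_down)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


# Toy 2x2 frozen weight with a rank-1 update (up: 2x1, down: 1x2).
W = [[1.0, 0.0], [0.0, 1.0]]
up = [[1.0], [2.0]]
down = [[0.5, 0.5]]

for s in (0.0, 0.7, 1.0):
    print(s, effective_weight(W, up, down, s))
```

With a scale of 0.0 the layer behaves exactly like the base model and 1.0 applies the full learned update, so 0.7 weakens the fine-tuned behaviour proportionally rather than removing only the unwanted parts of it.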

Nice to see that it works. I've found that when it overfits, tuning down both the unet and the text encoder helps, but with different scales. Also try training it with prior preservation. It seems like you are dreamboothing with faces, so a guide like https://github.com/nitrosocke/dreambooth-training-guide from @nitrosocke might help (it also says to crop out any artifacts).

cloneofsimo, Dec 14 '22 05:12
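Tuning the unet and text encoder down "with different scale" amounts to a small grid search. The sketch below only builds the list of (unet_scale, text_encoder_scale) pairs to try. The scale values are arbitrary assumptions, and the commented-out loop shows how each pair could be applied with tune_lora_scale (the lora_diffusion function used earlier in this thread) before sampling with a fixed seed; PROMPT is a placeholder.

```python
from itertools import product
from typing import List, Tuple


def scale_grid(unet_scales: List[float], text_scales: List[float]) -> List[Tuple[float, float]]:
    """Every (unet_scale, text_encoder_scale) pair, with the unet scale varying slowest."""
    return list(product(unet_scales, text_scales))


# Example values only; choose ranges that suit your own model.
grid = scale_grid([1.0, 0.8, 0.6], [1.0, 0.7, 0.4])

# Hypothetical usage with an already monkeypatched pipeline `pipe`:
# for unet_s, text_s in grid:
#     tune_lora_scale(pipe.unet, unet_s)
#     tune_lora_scale(pipe.text_encoder, text_s)
#     generator = torch.Generator("cuda").manual_seed(0)  # same seed for a fair comparison
#     image = pipe(PROMPT, generator=generator).images[0]
#     image.save(f"sweep_unet{unet_s}_text{text_s}.png")
```

Keeping the seed and prompt fixed means the only thing changing between images is the pair of scales, which makes the overfitting threshold of each component easier to spot.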

https://github.com/cloneofsimo/lora#what-happens-to-text-encoder-lora-and-unet-lora Have a look at this part as well

cloneofsimo, Dec 14 '22 05:12

I've posted some images of the effect of tuning both the unet and the text encoder (with prior preservation) in the discussions section:

https://github.com/cloneofsimo/lora/discussions/37

brian6091, Dec 14 '22 15:12

> I've posted some images of the effect of tuning both the unet and the text encoder (with prior preservation) in the discussions section:
>
> #37

Thanks for the amazing guide, @brian6091! Have you experimented with using a class dataset as regularization, similar to what was mentioned in the original Dreambooth paper? I tried it, but it gives me strange errors about accelerate.launch() not being instantiated. I think it could noticeably increase image quality compared to having the pipeline itself automatically generate class images.

JohnnyRacer, Dec 14 '22 21:12

> Nice to see that it works. I've found that when it overfits, tuning down both the unet and the text encoder helps, but with different scales. Also try training it with prior preservation. It seems like you are dreamboothing with faces, so a guide like https://github.com/nitrosocke/dreambooth-training-guide from @nitrosocke might help (it also says to crop out any artifacts).

From my experience so far, training with LoRA seems to require much cleaner and more specific datasets (e.g. aligned faces); training on more diverse datasets (e.g. pictures of a person in different poses) often leads to bad results.

JohnnyRacer, Dec 14 '22 22:12

>> I've posted some images of the effect of tuning both the unet and the text encoder (with prior preservation) in the discussions section: #37
>
> Thanks for the amazing guide, @brian6091! Have you experimented with using a class dataset as regularization, similar to what was mentioned in the original Dreambooth paper? I tried it, but it gives me strange errors about accelerate.launch() not being instantiated. I think it could noticeably increase image quality compared to having the pipeline itself automatically generate class images.

I used a class dataset for the training I showed in the discussion post; it was the same for both Dreambooth and LoRA. I used the SD pipeline to generate the class images. I didn't have time to show the effects of Dreambooth and LoRA on the class, but I may add some to a future discussion.

I was under the impression that the Dreambooth method was to use SD to generate class images. What did you do to get the error? If you post the error, maybe we could help debug it.

brian6091, Dec 15 '22 06:12

@brian6091

I am getting this error if I pass in a directory with --class_data_dir:

ValueError: Please make sure to properly initialize your accelerator via 'accelerator = Accelerator()' before using any functionality from the 'accelerate'  library.

Without this argument it trains fine. I've tried starting the script with both accelerate launch and plain python, but I get the same error. This is on the latest commit to the repo as of today. Any ideas why this is happening?

JohnnyRacer, Dec 15 '22 21:12

@JohnnyRacer Are you running in Colab? Which script or notebook are you using? I think I've seen this kind of error when there is a malformed input.

brian6091, Dec 15 '22 21:12

@brian6091

No I'm running it locally. I am using the train_lora_dreambooth.py script from this repo with no modifications. The ./aradesh_train_col is where I have my training images and ./aradesh_reg_col is where the regularization images are located.

The command below works:

```shell
accelerate launch train_lora_dreambooth.py \
  --pretrained_model_name_or_path="../sd-v1-5-diffusers" \
  --instance_data_dir="./aradesh_train_col" \
  --output_dir="./aradesh-sd1.5-textenc" \
  --instance_prompt="sks aradesh" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=600 \
  --use_8bit_adam \
  --train_text_encoder
```

The command below doesn't work and gives the accelerate error that I mentioned earlier:

```shell
accelerate launch train_lora_dreambooth.py \
  --pretrained_model_name_or_path="../sd-v1-5-diffusers" \
  --instance_data_dir="./aradesh_train_col" \
  --class_data_dir="./aradesh_reg_col" \
  --output_dir="./aradesh-sd1.5-textenc" \
  --instance_prompt="sks aradesh" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=600 \
  --use_8bit_adam \
  --train_text_encoder
```

JohnnyRacer, Dec 15 '22 23:12

@JohnnyRacer

I think adding this flag would work:

--with_prior_preservation

I would also add this:

--prior_loss_weight=1.0

since the default weight is oddly huge

brian6091, Dec 15 '22 23:12
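To see why --prior_loss_weight matters, here is a sketch of how a DreamBooth-style prior-preservation loss is usually combined: the reconstruction loss on instance images plus a weighted loss on class (prior) images. It uses plain Python lists in place of tensors and assumes the script follows the same formulation as the diffusers DreamBooth example; it is not the exact code of train_lora_dreambooth.py.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    assert len(pred) == len(target)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(
    instance_pred: Sequence[float],
    instance_target: Sequence[float],
    prior_pred: Sequence[float],
    prior_target: Sequence[float],
    prior_loss_weight: float = 1.0,
) -> float:
    """Instance loss plus the prior-preservation loss scaled by prior_loss_weight."""
    instance_loss = mse(instance_pred, instance_target)
    prior_loss = mse(prior_pred, prior_target)
    return instance_loss + prior_loss_weight * prior_loss
```

With a very large weight the class-image term dominates the gradient, so training is pushed mostly toward preserving the class rather than learning the new subject; a weight of 1.0 counts both terms equally.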

@brian6091

Thanks for the help, I got it to work but also had to include the --class_prompt argument as well. Did you use the same prompt for the --class_prompt as the --instance_prompt when you were training your models?

JohnnyRacer, Dec 15 '22 23:12
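The fix above suggests that --class_data_dir only makes sense together with --with_prior_preservation and --class_prompt. A hypothetical fail-fast check, written with argparse and using only flag names that appear in this thread, could look like the sketch below. train_lora_dreambooth.py does not necessarily validate its arguments this way, and the thread does not establish why the missing flag surfaced as an accelerate error; this simply makes the requirement explicit.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Subset of DreamBooth flags (illustrative)")
    parser.add_argument("--instance_data_dir", required=True)
    parser.add_argument("--instance_prompt", required=True)
    parser.add_argument("--class_data_dir")
    parser.add_argument("--class_prompt")
    parser.add_argument("--with_prior_preservation", action="store_true")
    # Default chosen for this sketch; the real script's default may differ.
    parser.add_argument("--prior_loss_weight", type=float, default=1.0)
    args = parser.parse_args(argv)

    # Fail early with a clear message instead of an obscure error later on.
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            parser.error("--with_prior_preservation requires --class_data_dir")
        if args.class_prompt is None:
            parser.error("--with_prior_preservation requires --class_prompt")
    elif args.class_data_dir is not None or args.class_prompt is not None:
        parser.error("--class_data_dir/--class_prompt need --with_prior_preservation")
    return args
```

parser.error prints the message and exits with status 2, so a bad combination stops the run before any model is loaded.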

@brian6091

Also, using the class_prompt argument seems to have increased VRAM usage by about 50% (from ~12 GB to ~18 GB). Is this expected?

JohnnyRacer, Dec 15 '22 23:12
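One possible reason for the VRAM jump (an assumption; the thread does not confirm it) is that in the diffusers DreamBooth example, which this script appears to be derived from, prior preservation appends the class examples to each batch, so every step processes twice as many images. A minimal sketch of such a collate step, with strings standing in for image tensors, illustrative dictionary keys, and "person" as a placeholder class prompt:

```python
from typing import Dict, List


def collate(examples: List[Dict[str, str]], with_prior_preservation: bool) -> Dict[str, List[str]]:
    """Build one training batch; with prior preservation, class examples are appended."""
    pixel_values = [ex["instance_images"] for ex in examples]
    prompts = [ex["instance_prompt"] for ex in examples]
    if with_prior_preservation:
        # Class (regularization) images go through the same forward pass,
        # doubling the effective batch size and the activation memory.
        pixel_values += [ex["class_images"] for ex in examples]
        prompts += [ex["class_prompt"] for ex in examples]
    return {"pixel_values": pixel_values, "prompts": prompts}


examples = [{
    "instance_images": "img_0", "instance_prompt": "sks aradesh",
    "class_images": "reg_0", "class_prompt": "person",
}]
```

Only activation memory grows with the batch; model weights and optimizer state stay the same size, which would be consistent with a ~50% rather than 100% increase in total VRAM.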

@JohnnyRacer So the class prompt should represent the class you don't want your model to forget. For example, if you are training for a specific person:

instance_prompt = "raretoken person"
class_prompt = "person"

I've never trained without prior preservation, so I don't know about the memory increase. If you are running into memory limitations, have a look at the table in the readme here: https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth#readme

brian6091, Dec 15 '22 23:12

@brian6091

Ah okay, I was using the same prompt for both earlier and wasn't seeing much of a difference. Is it possible to use gradient checkpointing for training now to reduce VRAM? In an earlier issue I saw that it was not saving properly.

JohnnyRacer, Dec 16 '22 00:12

Ah sorry, I forgot about that. I don't think gradient checkpointing is working yet (if you want to train the text encoder).

brian6091, Dec 16 '22 00:12

What is a good way to check whether a model has been trained with good (enough) parameters? So far the likeness is hit and miss, usually miss; that's with 5 images at 5000 steps.

G-force78, Dec 16 '22 13:12

You can sample a few prompts at intermediate checkpoints, say, every 500 or 1000 iterations. You can also track the loss, although that is very noisy.

brian6091, Dec 16 '22 13:12
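A minimal sketch of the two suggestions above: decide at which steps to sample (here every 500 steps, one of the intervals mentioned), and smooth the noisy per-step loss with an exponential moving average before looking at it. The helper names and the smoothing factor are assumptions for illustration, not something the training script provides.

```python
from typing import Iterable, List, Optional


def should_sample(step: int, every: int = 500) -> bool:
    """True on steps 500, 1000, ... (step counting starts at 1)."""
    return step > 0 and step % every == 0


def ema(values: Iterable[float], beta: float = 0.9) -> List[float]:
    """Exponential moving average; a higher beta gives a smoother, slower-reacting curve."""
    smoothed: List[float] = []
    avg: Optional[float] = None
    for v in values:
        avg = v if avg is None else beta * avg + (1 - beta) * v
        smoothed.append(avg)
    return smoothed
```

Inside a training loop you would call should_sample(global_step) and, when it returns True, generate a few images from a fixed set of prompts and seeds, while logging ema(losses) instead of the raw loss values.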

Ah, right, I see. So I just change the path here? monkeypatch_lora(pipe.unet, torch.load(os.path.join(OUTPUT_DIR, "lora_weight.pt")))

G-force78, Dec 16 '22 13:12

@G-force78 Not sure, depends on what you're using to run the training. But basically, around the part where you save a checkpoint, you need to:

  1. construct an inference pipeline
  2. monkeypatch it with your weights
  3. generate some images

Example here: https://github.com/brian6091/Dreambooth/blob/main/train.py#L934-1006

brian6091, Dec 16 '22 15:12
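The three steps listed above (construct a pipeline, monkeypatch it, generate images) can be sketched as a framework-agnostic helper. To keep it self-contained, the three steps are passed in as callables; in a real script they would wrap StableDiffusionPipeline.from_pretrained, monkeypatch_lora and a pipe(prompt) call. The helper name and signature are hypothetical, and the wiring in the trailing comment is untested.

```python
from typing import Callable, Dict, List, TypeVar

Pipe = TypeVar("Pipe")


def sample_checkpoint(
    build_pipeline: Callable[[], Pipe],
    patch_weights: Callable[[Pipe, str], object],
    generate: Callable[[Pipe, str], object],
    weights_path: str,
    prompts: List[str],
) -> Dict[str, object]:
    """1) build an inference pipeline, 2) patch in the LoRA weights, 3) generate one image per prompt."""
    pipe = build_pipeline()
    patch_weights(pipe, weights_path)
    return {prompt: generate(pipe, prompt) for prompt in prompts}


# Possible real-world wiring (MODEL is a placeholder):
# sample_checkpoint(
#     lambda: StableDiffusionPipeline.from_pretrained(MODEL, torch_dtype=torch.float16).to("cuda"),
#     lambda pipe, path: monkeypatch_lora(pipe.unet, torch.load(path)),
#     lambda pipe, prompt: pipe(prompt).images[0],
#     "lora_weight.pt",
#     ["sks aradesh"],
# )
```

Passing the steps in as functions keeps the checkpoint logic testable without a GPU and lets the same helper patch the text encoder as well, by supplying a different patch_weights callable.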

@brian6091 Just a quick thought on improving image quality: I was wondering if there's a way to use generated text captions to enhance training of the text encoder, similar to how embeddings are trained in Automatic1111's repo. Most of the trained models seem to have a hard time focusing on the desired details and are much more susceptible to picking up unwanted ones. By conditioning the text encoder on specific prompts matched to each image, it should be better at generating favorable samples.

JohnnyRacer, Dec 16 '22 21:12
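The per-image caption idea could be sketched as follows: look for a caption file next to each training image and fall back to the shared instance prompt. The sidecar ".txt" convention, the extension list and the function name are assumptions for illustration; the training script discussed in this thread only uses the single --instance_prompt.

```python
from pathlib import Path
from typing import List, Tuple

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}


def load_captions(data_dir: str, instance_prompt: str) -> List[Tuple[Path, str]]:
    """Pair each image with its caption from a same-named .txt file, if one exists."""
    pairs: List[Tuple[Path, str]] = []
    for image in sorted(Path(data_dir).iterdir()):
        if image.suffix.lower() not in IMAGE_SUFFIXES:
            continue
        caption_file = image.with_suffix(".txt")
        if caption_file.exists():
            # Prefix the rare-token prompt so the subject stays bound to it.
            caption = f"{instance_prompt}, {caption_file.read_text().strip()}"
        else:
            caption = instance_prompt
        pairs.append((image, caption))
    return pairs
```

Describing incidental details (background, clothing, a watermark) in the caption gives the text encoder somewhere else to put them, which is the intuition behind hoping it reduces unwanted details bleeding into the instance token.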

> @G-force78 Not sure, depends on what you're using to run the training. But basically, around the part where you save a checkpoint, you need to:
>
> 1. construct an inference pipeline
> 2. monkeypatch it with your weights
> 3. generate some images
>
> Example here: https://github.com/brian6091/Dreambooth/blob/main/train.py#L934-1006

I'm just using the Colab notebook:

```python
#@title LOADING MODEL AND MONKEY PATCHING IT
import os

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

# PRETRAINED_MODEL and OUTPUT_DIR are set in earlier cells of the Colab notebook.
pipe = StableDiffusionPipeline.from_pretrained(PRETRAINED_MODEL, torch_dtype=torch.float16).to("cuda")

# Patch the UNet with its trained LoRA weights.
monkeypatch_lora(pipe.unet, torch.load(os.path.join(OUTPUT_DIR, "lora_weight.pt")))

# Patch the text encoder too; its LoRA layers replace modules inside CLIPAttention.
monkeypatch_lora(
    pipe.text_encoder,
    torch.load(os.path.join(OUTPUT_DIR, "lora_weight.text_encoder.pt")),
    target_replace_module=["CLIPAttention"],
)
```

G-force78, Dec 16 '22 22:12

@G-force78

Are you having problems or is something not working? That is the correct way to run it.

JohnnyRacer, Dec 16 '22 23:12

@G-force78

Looks right. Now just run the pipe with a prompt.

brian6091, Dec 17 '22 08:12

> @G-force78
>
> Are you having problems or is something not working? That is the correct way to run it.

That was because I used 768px images with a 512px model.

I think I may have chosen a token that is already used by the SD 2.1 base model, as the generated images look nothing like my training images (5000 training steps, 5 images, prompt "bobo bobotest").

Example training image (Bob Odenkirk):

[image: 800px-Bob_Odenkirk_by_Gage_Skidmore_2]

G-force78, Dec 17 '22 11:12