
Gradient checkpointing blocks LoRA weight updates

Open brian6091 opened this issue 1 year ago • 14 comments

Running with gradient checkpointing prevents LoRA weight updates.

Description:

Ubuntu 18.04.6 LTS
diffusers==0.10.2
lora-diffusion==0.0.3
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.0%2Bcu116-cp38-cp38-linux_x86_64.whl
transformers==4.25.1
xformers @ https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl

Accelerate version: 0.15.0
Platform: Linux-5.10.133+-x86_64-with-glibc2.27
Python version: 3.8.16
Numpy version: 1.21.6
PyTorch version (GPU?): 1.13.0+cu116 (True)

With gradient checkpointing enabled,

!accelerate launch \
  --mixed_precision="fp16" \
  lora/train_lora_dreambooth.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --instance_data_dir="$INSTANCE_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --instance_prompt="$INSTANCE_PROMPT" \
  --train_text_encoder \
  --resolution=512 \
  --use_8bit_adam \
  --seed=1234 \
  --mixed_precision="fp16" \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --lr_scheduler="constant"

produces

Before training: Unet First Layer lora up
tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
Before training: Unet First Layer lora down
tensor([[-0.0125, 0.0331, 0.0198, ..., 0.0715, 0.0393, -0.1777], [-0.0442, 0.0572, 0.0026, ..., 0.0876, 0.0085, 0.0050], [ 0.0410, -0.0777, 0.0313, ..., -0.0613, -0.0111, -0.0451], [ 0.0202, -0.0079, 0.1156, ..., -0.0167, 0.0915, 0.0737]])
Before training: text encoder First Layer lora up
tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
Before training: text encoder First Layer lora down
tensor([[ 0.0175, 0.0458, 0.1019, ..., -0.1470, 0.1538, 0.0120], [-0.0307, -0.1303, 0.0911, ..., 0.0317, 0.0829, 0.0084], [-0.0016, 0.1495, -0.1105, ..., -0.0781, 0.0122, 0.0272], [ 0.0182, -0.0064, -0.0268, ..., 0.0800, 0.0745, 0.0231]])

and after some iterations

First Unet Layer's Up Weight is now : tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], device='cuda:0')
First Unet Layer's Down Weight is now : tensor([[-0.0125, 0.0331, 0.0198, ..., 0.0715, 0.0393, -0.1777], [-0.0442, 0.0572, 0.0026, ..., 0.0876, 0.0085, 0.0050], [ 0.0410, -0.0777, 0.0313, ..., -0.0613, -0.0111, -0.0451], [ 0.0202, -0.0079, 0.1156, ..., -0.0167, 0.0915, 0.0737]], device='cuda:0')
First Text Encoder Layer's Up Weight is now : tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], device='cuda:0')
First Text Encoder Layer's Down Weight is now : tensor([[ 0.0175, 0.0458, 0.1019, ..., -0.1470, 0.1538, 0.0120], [-0.0307, -0.1303, 0.0911, ..., 0.0317, 0.0829, 0.0084], [-0.0016, 0.1495, -0.1105, ..., -0.0781, 0.0122, 0.0272], [ 0.0182, -0.0064, -0.0268, ..., 0.0800, 0.0745, 0.0231]], device='cuda:0')

Disabling gradient checkpointing seems to work fine

Before training: Unet First Layer lora up
tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
Before training: Unet First Layer lora down
tensor([[-0.0125, 0.0331, 0.0198, ..., 0.0715, 0.0393, -0.1777], [-0.0442, 0.0572, 0.0026, ..., 0.0876, 0.0085, 0.0050], [ 0.0410, -0.0777, 0.0313, ..., -0.0613, -0.0111, -0.0451], [ 0.0202, -0.0079, 0.1156, ..., -0.0167, 0.0915, 0.0737]])
Before training: text encoder First Layer lora up
tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ..., [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
Before training: text encoder First Layer lora down
tensor([[ 0.0175, 0.0458, 0.1019, ..., -0.1470, 0.1538, 0.0120], [-0.0307, -0.1303, 0.0911, ..., 0.0317, 0.0829, 0.0084], [-0.0016, 0.1495, -0.1105, ..., -0.0781, 0.0122, 0.0272], [ 0.0182, -0.0064, -0.0268, ..., 0.0800, 0.0745, 0.0231]])

and after some iterations

First Unet Layer's Up Weight is now : tensor([[-2.6580e-03, 7.4147e-04, 1.7193e-03, 1.7760e-04], [ 1.1531e-03, 3.4420e-04, 1.6359e-03, -2.6158e-05], [-2.0090e-04, 9.7763e-04, 9.0458e-04, -1.2152e-03], ..., [ 1.3022e-03, -1.6245e-03, 1.3225e-03, -2.2149e-03], [-4.9904e-04, 7.6633e-04, -1.1046e-03, 8.2197e-04], [ 2.1200e-03, -7.4285e-04, -2.7083e-03, 7.7677e-04]], device='cuda:0')
First Unet Layer's Down Weight is now : tensor([[-0.0120, 0.0336, 0.0196, ..., 0.0711, 0.0415, -0.1778], [-0.0446, 0.0560, 0.0023, ..., 0.0877, 0.0085, 0.0037], [ 0.0402, -0.0783, 0.0297, ..., -0.0606, -0.0117, -0.0448], [ 0.0169, -0.0067, 0.1153, ..., -0.0172, 0.0923, 0.0734]], device='cuda:0')
First Text Encoder Layer's Up Weight is now : tensor([[ 2.7144e-05, 4.5192e-05, -4.2374e-05, 3.5689e-05], [ 1.5236e-04, 2.1131e-04, 2.3639e-04, 1.8105e-04], [ 1.6095e-04, -1.8110e-04, -6.2436e-05, 1.2356e-04], ..., [ 1.3739e-04, -1.1521e-04, -1.0960e-04, 1.2269e-04], [ 1.6732e-05, -1.3146e-05, -2.5539e-04, 1.7016e-04], [ 2.5715e-04, -3.0459e-04, -1.9317e-04, -2.3927e-04]], device='cuda:0')
First Text Encoder Layer's Down Weight is now : tensor([[ 0.0172, 0.0456, 0.1016, ..., -0.1473, 0.1535, 0.0117], [-0.0304, -0.1304, 0.0914, ..., 0.0319, 0.0832, 0.0087], [-0.0014, 0.1495, -0.1103, ..., -0.0783, 0.0124, 0.0275], [ 0.0183, -0.0065, -0.0266, ..., 0.0801, 0.0748, 0.0234]], device='cuda:0')
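For reference, whether the LoRA weights move at all can be checked with something along these lines (a minimal sketch; it assumes the injected modules from lora-diffusion expose the low-rank factors as lora_up / lora_down nn.Linear sub-layers):

import torch

def first_lora_weights(model):
    # Snapshot the up/down factors of the first LoRA-injected module found.
    for module in model.modules():
        if hasattr(module, "lora_up") and hasattr(module, "lora_down"):
            return (module.lora_up.weight.detach().clone(),
                    module.lora_down.weight.detach().clone())
    return None, None

up_before, down_before = first_lora_weights(unet)
# ... run a few optimizer steps ...
up_after, down_after = first_lora_weights(unet)
print("lora up changed:  ", not torch.equal(up_before, up_after))
print("lora down changed:", not torch.equal(down_before, down_after))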

brian6091 avatar Dec 13 '22 15:12 brian6091

In my case the generated text encoder weights do not affect image generation at all... but regular LoRA models do work. Maybe it's the same issue as this one, since I also use gradient_checkpointing.

qunash avatar Dec 13 '22 21:12 qunash

Same here. I think it's related to this warning I'm getting:

torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
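
That warning points at the underlying mechanism: with the reentrant checkpointing that PyTorch 1.13 uses by default, a checkpointed segment whose tensor inputs do not require grad is detached from the autograd graph, so trainable layers inside it never receive gradients. A self-contained sketch (stand-in modules, not the actual training code):

import torch
import torch.nn as nn
from torch.utils import checkpoint

frozen = nn.Linear(4, 4).requires_grad_(False)  # stands in for the frozen base weights
lora = nn.Linear(4, 4)                          # stands in for a trainable LoRA layer inside the block
head = nn.Linear(4, 4)                          # trainable layer outside the checkpointed block

def block(t):
    return lora(frozen(t))

x = torch.randn(2, 4)                  # activations entering the block, requires_grad=False

out = checkpoint.checkpoint(block, x)  # emits the UserWarning above
head(out).sum().backward()
print(lora.weight.grad)                # None: the checkpointed output is cut off from the graph
print(head.weight.grad is not None)    # True: layers outside the checkpoint still get gradients

# Making the block's input require grad (what the workarounds below do) restores the gradients:
x.requires_grad_(True)
out = checkpoint.checkpoint(block, x)
head(out).sum().backward()
print(lora.weight.grad is not None)    # True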

matteoserva avatar Dec 13 '22 21:12 matteoserva

Yes, gradient checkpointing doesn't work right now, sorry...

cloneofsimo avatar Dec 13 '22 22:12 cloneofsimo

I can confirm what @qunash found. With gradient checkpointing enabled, the LoRA weights for both the unet and the text encoder do not change when printed to screen. However, in the saved .pt files, the LoRA weights for the unet are actually updated, whereas those for the text encoder are not (all have requires_grad=False, which produces the warning @matteoserva shows).

The method for enabling gradient checkpointing is different for the unet (enable_gradient_checkpointing()) and the text encoder (gradient_checkpointing_enable()), so maybe there is a clue there...
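
For reference, the two calls are (assuming a diffusers UNet2DConditionModel and a transformers CLIPTextModel, as used by the training script above):

unet.enable_gradient_checkpointing()          # diffusers API
text_encoder.gradient_checkpointing_enable()  # transformers API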

brian6091 avatar Dec 13 '22 22:12 brian6091

This would explain why my final trained model always gives me exactly the same images as the original Stable Diffusion v2.1 768px model. Since my RTX 4080 is limited to 16GB, please make gradient checkpointing work; when I disable it, I get CUDA OOM errors. DreamBooth training in the original diffusers repo works fine for me with fp16 and gradient_checkpointing.

djdookie avatar Dec 14 '22 15:12 djdookie

It seems that at least one of the inputs must have requires_grad=True for torch.utils.checkpoint to work.

A simple workaround for UNet-only training is to set unet.conv_in.requires_grad_(True).
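
A sketch of that workaround in the terms of this script (assuming a diffusers UNet2DConditionModel whose base weights are frozen and where only the injected LoRA parameters go to the optimizer; the parameter-collection name is illustrative):

import torch

unet.requires_grad_(False)             # base weights stay frozen
unet.enable_gradient_checkpointing()

# torch.utils.checkpoint needs at least one grad-requiring input to the
# checkpointed blocks; conv_in feeds them, so marking it trainable is enough.
unet.conv_in.requires_grad_(True)

# conv_in itself is never passed to the optimizer, so its weights do not change;
# only the injected LoRA parameters (lora_params here is illustrative) are updated.
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)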

laksjdjf avatar Jan 19 '23 02:01 laksjdjf

I face the same problem when enabling gradient checkpointing. Is there a way to solve it when training the text encoder and unet jointly?

kingofprank avatar Feb 08 '23 07:02 kingofprank

@kingofprank for joint training, I've been enabling gradient checkpointing for the unet only and just not enabling it for the text encoder. This works, and I think you get most of the benefit, since the unet has so many more parameters than the text encoder.

brian6091 avatar Feb 08 '23 07:02 brian6091

@brian6091 thanks for your advice, but I only have an 11GB 2080Ti, so I run into OOM problems when gradient checkpointing is disabled for the text encoder.

kingofprank avatar Feb 08 '23 08:02 kingofprank

> I face the same problem when enabling gradient checkpointing. Is there a way to solve it when training the text encoder and unet jointly?

Kohya's repo seems to have solved it this way: https://github.com/kohya-ss/sd-scripts/commit/e6a8c9d269b4952a6944dfe0e78a1f89bd036971

laksjdjf avatar Feb 08 '23 09:02 laksjdjf

> I face the same problem when enabling gradient checkpointing. Is there a way to solve it when training the text encoder and unet jointly?

> Kohya's repo seems to have solved it this way: kohya-ss/sd-scripts@e6a8c9d

Interesting, but what if you didn't want to train the embeddings?

brian6091 avatar Feb 08 '23 09:02 brian6091

> Interesting, but what if you didn't want to train the embeddings?

In the case of LoRA, the embeddings parameter is not passed to the optimizer, so it is not trained.
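
A sketch of that trick (not necessarily the exact sd-scripts implementation; it assumes a transformers CLIPTextModel, whose embedding module lives at text_encoder.text_model.embeddings, and an illustrative name for the LoRA parameter list):

import torch

text_encoder.requires_grad_(False)            # base text encoder stays frozen
text_encoder.gradient_checkpointing_enable()

# Mark the embedding module as requiring grad so that the checkpointed
# transformer layers receive a grad-requiring input; gradients can then
# reach the injected LoRA layers, and the warning disappears.
text_encoder.text_model.embeddings.requires_grad_(True)

# The embeddings are not handed to the optimizer, so they are never updated;
# only the LoRA parameters (lora_params is an illustrative name) train.
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)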

laksjdjf avatar Feb 08 '23 09:02 laksjdjf

Ah, nice trick. Thanks

brian6091 avatar Feb 08 '23 10:02 brian6091

@laksjdjf It works, thx~

kingofprank avatar Feb 08 '23 11:02 kingofprank