Learning-Rate-Free Learning Algorithm

Open BootsofLagrangian opened this issue 2 years ago • 33 comments

Hi, how about D-adaptation?

This is a kind of algorithm with which the end user doesn't need to set a specific learning rate.

In short, D-adaptation uses boundedness (a lower bound on the distance to the solution) to find a proper learning rate.

So it might be useful to anyone who finds it hard to tune hyperparameters.
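For anyone wondering what "boundedness" refers to, here is a short sketch of the motivation given in the D-Adaptation paper, under textbook assumptions (convex f, subgradients bounded by G, averaged iterate, n steps with a fixed step size). The best fixed step size depends on the unknown distance D to a solution; D-Adaptation replaces D with a running lower bound d_k. How d_k is actually updated is omitted here.

```latex
% Subgradient method, convex f, \|g_k\| \le G, n steps, fixed step size \gamma:
f(\bar{x}_n) - f_\star \le \frac{D^2}{2\gamma n} + \frac{\gamma G^2}{2},
\qquad D = \|x_0 - x_\star\|
% The bound is minimised by
\gamma^\star = \frac{D}{G\sqrt{n}}
\quad\Rightarrow\quad
f(\bar{x}_n) - f_\star \le \frac{D G}{\sqrt{n}}
% D is unknown, so D-Adaptation keeps a non-decreasing lower bound
d_0 \le d_1 \le \dots \le d_k \le D
% and uses d_k where the step size formula needs D.
```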

Before writing this issue, I implemented a D-adaptation optimizer (Adam) for LoRA. It works!

Only a few code changes are needed. But I don't know all of the sd-scripts code, so there is some hard-coding.

The only requirements for D-adaptation are torch>=1.5.1 and pip install dadaptation.

Here is the code.

In train_network.py, I added

import torch.optim as optim  # used for a raw learning rate scheduler
import dadaptation

and hard-coded the optimizer, changing

optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

to

optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=1.0)

Setting decouple=True means the optimizer behaves like AdamW rather than Adam, and weight_decay is the weight decay (L2 penalty) strength.

The other arguments are probably not meant for end users.

And trainable_params doesn't need a specific learning rate, so replace

trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

with

trainable_params = network.prepare_optimizer_params(None, None)
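Why passing None for both learning rates is enough: PyTorch optimizers fill any key missing from a param group with the optimizer's defaults, so groups without an "lr" entry pick up lr=1.0 from DAdaptAdam(..., lr=1.0). A minimal pure-Python sketch, assuming prepare_optimizer_params builds one group for the text encoder and one for the UNet and only sets "lr" when a value is passed; prepare_param_groups and apply_defaults are hypothetical stand-ins, not sd-scripts or PyTorch functions.

```python
# Hypothetical stand-in for network.prepare_optimizer_params: one param group
# per module (text encoder first, then UNet), "lr" only set when one is given.
def prepare_param_groups(te_params, unet_params, text_encoder_lr=None, unet_lr=None):
    groups = []
    for params, lr in ((te_params, text_encoder_lr), (unet_params, unet_lr)):
        if not params:
            continue  # skip modules that have no LoRA weights
        group = {"params": list(params)}
        if lr is not None:
            group["lr"] = lr
        groups.append(group)
    return groups


def apply_defaults(groups, defaults):
    """Mimic how torch.optim.Optimizer fills missing group keys from its defaults."""
    return [{**defaults, **group} for group in groups]


groups = prepare_param_groups(["te_weight"], ["unet_weight"])  # both lrs left as None
filled = apply_defaults(groups, {"lr": 1.0})  # like DAdaptAdam(..., lr=1.0)
print([g["lr"] for g in filled])  # [1.0, 1.0]
```

If a per-group lr is passed instead, that group keeps its own value and only the other one falls back to the default.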

In sd-scripts, lr_scheduler is the return value of the get_scheduler_fix function.

But for some reason I don't understand, using get_scheduler_fix interferes with D-adaptation,

so I overrode lr_scheduler with LambdaLR. Sorry for hard-coding again :)

lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=[lambda epoch: 1, lambda epoch: 1], last_epoch=-1, verbose=False)

To monitor the d*lr value, a log entry such as

logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']

might be needed. That's all.
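The line above logs only param_groups[0], which is the text encoder group when both groups are present. A small sketch that logs d*lr for every group instead, assuming each group dict carries the 'd' and 'lr' entries that line reads; dlr_logs and the log key names are hypothetical.

```python
def dlr_logs(param_groups, names=("te", "unet")):
    """Return d*lr for every param group, keyed by a readable group name.

    Assumes each group dict has 'd' (the D-Adaptation estimate) and 'lr'.
    Groups beyond the named ones are called group2, group3, ...
    """
    logs = {}
    for i, group in enumerate(param_groups):
        name = names[i] if i < len(names) else f"group{i}"
        logs[f"lr/d*lr/{name}"] = group["d"] * group["lr"]
    return logs


# Fake param groups standing in for optimizer.param_groups after a LambdaLR
# has scaled the text encoder group to 0.5 (values are illustrative only).
groups = [{"d": 1e-4, "lr": 0.5}, {"d": 1e-4, "lr": 1.0}]
print(dlr_logs(groups))
```

In train_network.py this would be merged into the existing logs dict, e.g. logs.update(dlr_logs(optimizer.param_groups)).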

[image]

This image is the d*lr vs. step graph from my D-adaptation run.

I trained a LoRA using D-adaptation; the result is here.

Thank you!

BootsofLagrangian, Feb 11 '23 14:02

@BootsofLagrangian Would you be able to fork the repo and commit your changes so it would be easier for plebs like me to follow your changes?

AI-Casanova, Feb 11 '23 19:02

Sorry, I'm not familiar with GitHub. It would take me a long time to make a fork or repo.

Also, these changes include hard-coded values. Does it matter if there's something like this?

BootsofLagrangian, Feb 12 '23 03:02

When you fork, the code becomes your own, and you can hard code changes into your own copy. But that's ok. Maybe I'll do it, and @ you if I have any problems.

(Also, unless you're on mobile, forking is fast and easy, click fork, click done, and boom)

AI-Casanova, Feb 12 '23 06:02

I made a fork.

I just changed train_network.py and requirements.txt

BootsofLagrangian, Feb 12 '23 08:02

Cool stuff

bmaltais, Feb 12 '23 11:02

This should be added as an official feature to the project. Like it!

bmaltais, Feb 12 '23 11:02

@BootsofLagrangian I see that both TE LR and UNet LR are no longer specified. Do you know if D-adaptation sets both to the same value? If it does, is it possible to set them to different values? For LoRA, it used to be that setting the TE to a smaller LR than the UNet worked better. I'm not sure how this handles each.

bmaltais, Feb 12 '23 12:02

@bmaltais wouldn't you have to run it twice, with lr=1.0 for the UNet and <1 for the TE? In essence you have two different training problems going on at once.

From the source repo: "Set the LR parameter to 1.0. This parameter is not ignored, rather, setting it larger or smaller will directly scale up or down the D-Adapted learning rate." So 1.0 and 0.5 sound like they would match the commonly used settings (1e-4 and 5e-5).

And maybe D-adaptation is best suited for the UNet, since underfitting the text encoder is often desirable.

AI-Casanova, Feb 12 '23 13:02

@BootsofLagrangian you're awesome! Can't wait to play with it!

AI-Casanova, Feb 12 '23 13:02

Yes, using the lr argument makes the TE LR and UNet LR different. @AI-Casanova's comment is also right.

I'm not sure, but properly using the get_scheduler_fix function in train_network.py is the way to apply the LRs differently.

Or directly:

lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=[lambda epoch: 0.5, lambda epoch: 1], last_epoch=-1, verbose=False)
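What the two lambdas do: LambdaLR takes one lambda per param group and sets each group's lr to its initial lr times that lambda, so with lr=1.0 the groups end up at 0.5 (TE) and 1.0 (UNet). Assuming, as this thread does, that D-Adaptation's step for a group is d times that group's lr, the TE then moves at half the UNet rate. A pure-Python sketch; lambda_lr is a hypothetical helper that mirrors the LambdaLR formula rather than calling PyTorch, and the d value is illustrative, not measured.

```python
def lambda_lr(base_lrs, lr_lambdas, epoch):
    """Per-group lr as LambdaLR computes it: lr_i = base_lr_i * lambda_i(epoch)."""
    if len(base_lrs) != len(lr_lambdas):
        raise ValueError("LambdaLR needs one lambda per param group")
    return [base * fn(epoch) for base, fn in zip(base_lrs, lr_lambdas)]


# Both groups start at lr=1.0 (DAdaptAdam default); TE first, UNet second.
lrs = lambda_lr([1.0, 1.0], [lambda epoch: 0.5, lambda epoch: 1], epoch=0)
d = 4.0e-05  # illustrative D-Adaptation estimate
print(lrs)  # [0.5, 1.0]
print([d * lr for lr in lrs])  # effective step scale per group (TE, UNet)
```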

BootsofLagrangian, Feb 12 '23 14:02

I also discovered there are two other adaptive methods. I was shocked at how high the SGD method ramped up the LR (to 1.03e+00), but the results were still good. My god!

Sample from SGD training:

[image: grid-0432]

bmaltais, Feb 12 '23 15:02

Link to python module for reference: https://pypi.org/project/dadaptation/

bmaltais, Feb 12 '23 15:02

I intuitively knew that there must be a way of adjusting the learning rate in a context-dependent manner, but knew I was far too uninformed to come up with one. This is definitely cool stuff.

AI-Casanova, Feb 12 '23 15:02

Quick comparison of DAdaptAdam with TE:1.0/UNet:1.0 versus TE:0.5/UNet:1.0:

DAdaptAdam-1-1: loss 0.125, dlr 4.02e-05
DAdaptAdam-0.5-1: loss 0.124, dlr 4.53e-05

DAdaptAdam-1-1: [image: grid-0436]

DAdaptAdam-0.5-1: [image: grid-0434]

I think the winner is clear: the TE LR needs to be half of the UNet LR... but there might be better settings.

The optimizer config for both was:

optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=1e-6)

I will redo the same test with this optimizer config:

optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)

bmaltais, Feb 12 '23 16:02

@bmaltais how did you implement the split learning rate? Or did you run it twice?

AI-Casanova, Feb 12 '23 17:02

@AI-Casanova I did it with

lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=[lambda epoch: 0.5, lambda epoch: 1], last_epoch=-1, verbose=False)

bmaltais, Feb 12 '23 17:02

@bmaltais awesome! I should have pulled on that thread, but my self taught lr for all things python and ML is already through the roof. 😅

AI-Casanova, Feb 12 '23 18:02

Here is an interesting finding: for DAdaptSGD, having the TE and UNet lambdas both at 1 is better than 0.5/1...

DAdaptSGD-1-1: [image: grid-0438]

DAdaptSGD-0.5-1: [image: grid-0437]

I wonder if having a weaker UNet with DAdaptSGD might be even better... like DAdaptSGD-1-0.5

Also, I have not been able to get anything out of DAdaptAdaGrad yet.

bmaltais, Feb 12 '23 19:02

And here are the results of DAdaptSGD-1-0.5:

[image: grid-0439]

I think DAdaptSGD-1-1 is still the best config for that method.

Well... I am looking at the results and I am not so sure anymore... Maybe DAdaptSGD-1-0.5 is better...

bmaltais, Feb 12 '23 19:02

SGD is stochastic gradient descent, right? Is that the same concept as SGD with batch=1?

Or is SGD scheduling about not having weight decay like Adam?

Is batch=1 even SGD with Adam?

Primary sources are impenetrable and secondary sources so unreliable on this stuff.
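On the SGD question: "stochastic" refers to using gradients estimated from a minibatch (batch=1 is just the smallest case), while SGD vs. Adam refers to the update rule, so batch size and the choice of optimizer are independent. A minimal scalar sketch of the two textbook update rules (plain SGD without momentum, standard Adam with bias correction); it does not reproduce the D-Adaptation variants, and sgd_step/adam_step are illustrative helpers.

```python
import math


def sgd_step(p, g, lr):
    """Plain SGD: move against the gradient, scaled by lr."""
    return p - lr * g


def adam_step(p, g, lr, state, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step for a scalar parameter; state holds m, v and step count t."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * g
    state["v"] = beta2 * state["v"] + (1 - beta2) * g * g
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


# Same gradient and lr: SGD's step scales with |g|, Adam's first step is about lr.
g, lr = 2.0, 0.1
print(sgd_step(1.0, g, lr))  # moves by lr * g = 0.2
state = {"m": 0.0, "v": 0.0, "t": 0}
print(adam_step(1.0, g, lr, state))  # moves by roughly lr = 0.1
```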

AI-Casanova, Feb 12 '23 19:02

Good question... I don't really know. But DAdaptAdam-0.5-1 appears to produce the best likeness of all the methods, so I might stick with that for now...

bmaltais, Feb 12 '23 20:02

Published 1st model made with this new technique: https://civitai.com/models/8337/kim-wilde-1980s-pop-star

bmaltais, Feb 12 '23 20:02

I'm experiencing what I think is a way overtrained TE, even at 0.5. All styling goes out the window before my UNet catches up.

I have to figure out how to log the learning rates independently.

AI-Casanova, Feb 13 '23 15:02

So @BootsofLagrangian's fork was outputting the TE learning rate to the progress bar and logs, which means what I thought was a suspiciously high UNet LR was actually an insanely high TE LR.

I dropped my scales to 0.25 (TE) and 0.5 (UNet) and am trying again.

AI-Casanova, Feb 13 '23 16:02

Unfortunately it's starting to look to me like I've replaced one grid search with another, with scaling factor in the place of lr

AI-Casanova, Feb 13 '23 17:02

@AI-Casanova, you might need another learning rate scheduler. My fork uses only LambdaLR (identity or scalar scaling).

This is a limitation, because the fork doesn't use the get_scheduler_fix function from sd-scripts.

Transformer models usually use a warmup LR scheduler.

According to the dadaptation repo, applying the LR scheduler you used before also works fine.
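A warmup can still be expressed with LambdaLR by making each lambda ramp up before settling at its scale. A minimal sketch, assuming a linear warmup over a hypothetical 100 steps and the TE 0.5 / UNet 1.0 scaling used earlier in the thread; linear_warmup is a hypothetical helper, not an sd-scripts function.

```python
def linear_warmup(warmup_steps):
    """Return an lr_lambda that ramps from 1/warmup_steps up to 1.0, then stays at 1.0."""
    def lr_lambda(step):
        return min(1.0, (step + 1) / warmup_steps)
    return lr_lambda


warmup = linear_warmup(100)
# One lambda per param group (TE first, UNet second), each scaled after warmup.
lr_lambdas = [
    lambda step: 0.5 * warmup(step),  # text encoder
    lambda step: 1.0 * warmup(step),  # UNet
]
print([warmup(s) for s in (0, 49, 99, 500)])  # [0.01, 0.5, 1.0, 1.0]
```

The list would be passed as lr_lambda=lr_lambdas to optim.lr_scheduler.LambdaLR, with the scheduler stepped once per optimizer step so that "step" counts optimizer steps rather than epochs.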

BootsofLagrangian, Feb 14 '23 03:02

@BootsofLagrangian basically, what I was seeing were very good likenesses, but they were so inflexible.

I think I might have hit the sweet spot at 0.125 (TE) and 0.25 (UNet), though.

It still adjusts to my datasets, and is in a similar range as before.

Now I'm gonna add a few other ideas to this fork.

AI-Casanova, Feb 14 '23 04:02

@BootsofLagrangian I have tried the fork, and it seems to misbehave when network_alpha is not equal to network_dim. Is it expected behavior that the smaller network_alpha is, the higher the learning rate?

With network_dim=128 and network_alpha=1, the output was destroyed after about 50 steps.

[image]

tsukimiya, Feb 15 '23 22:02

D-adaptation uses the inverse of the model's subgradient. If you want the equations, the details are in the D-adaptation paper.

A LoRA model multiplies two low-rank (rank r) matrices, B and A.

In the LoRA paper, alpha and rank appear as an external scaling term on the model: the LoRA update is multiplied by alpha and divided by the rank.
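Written out in the LoRA paper's form (B is d×r, A is r×k, x is the input), the scaling and, by the chain rule, its effect on both LoRA gradients:

```latex
h = W_0 x + \frac{\alpha}{r}\, B A x
\qquad
\frac{\partial \mathcal{L}}{\partial B} = \frac{\alpha}{r}\,\frac{\partial \mathcal{L}}{\partial h}\,(A x)^{\top}
\qquad
\frac{\partial \mathcal{L}}{\partial A} = \frac{\alpha}{r}\, B^{\top}\,\frac{\partial \mathcal{L}}{\partial h}\, x^{\top}
```

Both gradients carry the same factor alpha/r, so with alpha = 1 and r = 128 they are 1/128 of what they would be with alpha = r.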

So the alpha/rank ratio acts very directly and sensitively on the subgradient.

In the destroyed case, alpha=1 and rank=128, so the alpha/rank ratio is 1/128. This makes the subgradient smaller.
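A quick numeric check of that 1/128 factor: same input, same A and same upstream gradient, with alpha=128 versus alpha=1 at rank 128. lora_grad_B is a hypothetical pure-Python helper; it only shows the gradient scaling and does not model how D-Adaptation turns gradient size into a step size.

```python
def lora_grad_B(alpha, r, dL_dh, A, x):
    """Gradient of the loss w.r.t. B for h = W0 x + (alpha / r) * B A x.

    dL_dh: list of length d, A: r x k nested list, x: list of length k.
    Returns a d x r nested list equal to (alpha / r) * outer(dL_dh, A x).
    """
    scale = alpha / r
    ax = [sum(a * xj for a, xj in zip(row, x)) for row in A]  # A x, length r
    return [[scale * g * u for u in ax] for g in dL_dh]


r = 128
A = [[1.0, 0.0] for _ in range(r)]  # r x k with k = 2 (toy values)
x = [2.0, 3.0]
dL_dh = [1.0, -1.0]  # toy upstream gradient, d = 2

g_full = lora_grad_B(alpha=128, r=r, dL_dh=dL_dh, A=A, x=x)  # alpha == rank
g_small = lora_grad_B(alpha=1, r=r, dL_dh=dL_dh, A=A, x=x)  # alpha = 1
ratio = g_small[0][0] / g_full[0][0]
print(ratio)  # 0.0078125, i.e. 1/128
```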

Now, back to D-adaptation: a small subgradient makes the learning rate higher, and a high learning rate blows the model up.

Therefore, it is highly recommended to set alpha and rank to the same value, especially when using a big(?) rank.

Thank you for comment and experiments! :)

BootsofLagrangian, Feb 16 '23 14:02

Understood. If that is the case, when actually incorporating the code, it would be better to show a warning when alpha is set small, etc.
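Such a warning could look roughly like this. A minimal sketch: the function name and the min_ratio threshold are hypothetical (not values from sd-scripts or the dadaptation package), and where it would be called in train_network.py is left open.

```python
import warnings


def check_alpha_for_dadaptation(network_dim, network_alpha, min_ratio=0.5):
    """Warn when network_alpha / network_dim is small, since that shrinks LoRA gradients.

    min_ratio is a hypothetical threshold. Returns the alpha/dim ratio.
    """
    ratio = network_alpha / network_dim
    if ratio < min_ratio:
        warnings.warn(
            f"network_alpha/network_dim = {ratio:.4g} is small; with D-Adaptation this "
            "can push the adapted learning rate up and blow up training. "
            "Consider setting network_alpha equal to network_dim.",
            stacklevel=2,
        )
    return ratio


# The failing case from this thread: dim 128, alpha 1 triggers a UserWarning.
check_alpha_for_dadaptation(network_dim=128, network_alpha=1)
```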

Thanks for your reply!

tsukimiya, Feb 16 '23 15:02