
--stop_text_encoder_training support for LoRA/LyCORIS/etc training

Open sangoi-exe opened this issue 1 year ago • 24 comments

Hello,

I used the changes made by @larsupb, which were mistakenly proposed in @bmaltais's repository, with some adjustments of my own to fit the current state of train_network.py, to enable support for --stop_text_encoder_training in LoRA training.

https://github.com/bmaltais/kohya_ss/pull/1543

I did some tests and it seems to be working fine.

sangoi-exe avatar Dec 21 '23 06:12 sangoi-exe
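
A minimal sketch of the kind of stop hook being discussed for train_network.py's training loop. The argument name --stop_text_encoder_training comes from the thread; the helper name and the list handling (which anticipates the SDXL case reported below) are illustrative, not the exact code from the linked PR:

    # Illustrative sketch only; not the exact sd-scripts/PR implementation.
    def maybe_stop_text_encoder_training(text_encoder, global_step, stop_step):
        """Freeze the text encoder(s) once global_step reaches stop_step."""
        if stop_step is None or global_step != stop_step:
            return
        # SD1.5 passes a single module, SDXL a list of two encoders; handle both.
        encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
        for te in encoders:
            te.requires_grad_(False)  # stop gradient updates for the base encoder
            te.eval()                 # also switch dropout etc. to inference behaviour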

\train_network.py", line 747, in train text_encoder.requires_grad_(False) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ AttributeError: 'list' object has no attribute 'requires_grad_' When training sdxl's lora

suede299 avatar Jan 18 '24 06:01 suede299

\train_network.py", line 747, in train text_encoder.requires_grad_(False) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ AttributeError: 'list' object has no attribute 'requires_grad_' When training sdxl's lora

Yep, same here.

Edit: I am working with this now, and the SDXL fix for this oversight is:

                        # on SDXL, `text_encoder` is a list of two encoders, so freeze each one
                        for tensor in text_encoder:
                            tensor.requires_grad_(False)

I will see if his PR works.

DarkAlchy avatar Feb 07 '24 01:02 DarkAlchy

@DarkAlchy, sorry, the solution wasn't entirely developed by me; I used an idea that was sent to bmaltais's repository, with some adjustments because of the many updates Kohya's scripts have had since then.

Another detail is that I only train 1.5, so I ended up not testing the modifications with SDXL.

sangoi-exe avatar Feb 07 '24 04:02 sangoi-exe

@DarkAlchy, sorry, the solution wasn't entirely developed by me; I used an idea that was sent to bmaltais's repository, with some adjustments because of the many updates Kohya's scripts have had since then.

Another detail is that I only train 1.5, so I ended up not testing the modifications with SDXL.

I figured. I am not entirely sure it works, as I am training in Dreambooth. My fix takes care of the error, but I had to modify the TensorBoard section to get it to stop reporting on the TE when it hits the step. The thing is, I could not find the script it calls to train the TE itself, which is what we need to get to and alter so the LR is no longer updated for the TE and it stays where it is; on the TensorBoard graph that would show as a flat line and would not require any changes in that section. To be honest, Kohya needs to just do it, as it wouldn't be as hard for the originator of the code as it is for one of us to read this huge mess. I know I have a hard time with it.

DarkAlchy avatar Feb 07 '24 04:02 DarkAlchy

Thank you for opening this PR. I'm not sure text_encoder.requires_grad_(False) stops the training for the Text Encoder; it may be necessary to set requires_grad_ on the network instead. I will test it soon.

kohya-ss avatar Feb 07 '24 12:02 kohya-ss
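
For context on the point above: with LoRA the trainable parameters live in the injected network modules, not in the base text encoder (which is already frozen), so the freeze may need to target the network's text-encoder LoRA modules rather than text_encoder itself. A minimal sketch, assuming the network object exposes those modules as a text_encoder_loras list (the attribute name is an assumption; check networks/lora.py):

    # Sketch only; assumes `network` has a `text_encoder_loras` list of LoRA modules.
    def freeze_text_encoder_loras(network):
        for lora_module in getattr(network, "text_encoder_loras", []):
            lora_module.requires_grad_(False)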

Thank you for opening this PR. I'm not sure text_encoder.requires_grad_(False) stops the training for the Text Encoder; it may be necessary to set requires_grad_ on the network instead. I will test it soon.

You're welcome! The code is kinda like a copy of the method used in DreamBooth to stop the TE training, just not sure if it works as expected. I can only confirm that the training changes somehow when the TE stop is activated.

sangoi-exe avatar Feb 07 '24 13:02 sangoi-exe

Thank you for opening this PR. I'm not sure text_encoder.requires_grad_(False) stops the training for the Text Encoder; it may be necessary to set requires_grad_ on the network instead. I will test it soon.

You're welcome! The code is kinda like a copy of the method used in DreamBooth to stop the TE training, just not sure if it works as expected. I can only confirm that the training changes somehow when the TE stop is activated.

All it is doing is yanking the TE out of the gradients being updated (hence the warning that grads will be zero). It still "trains" the TE, only the weight gradients should be zero when the loss does its thing. It is a brutish hack that should work, but Kohya really needs to implement this properly.

DarkAlchy avatar Feb 07 '24 14:02 DarkAlchy

Could this be considered again for inclusion? It would be really helpful to be able to stop training the text encoder at a certain point but continue to train the unet. Unfortunately, the existing network_train_unet_only option won't do it when resuming from an existing LoRA, since it will just discard all the text encoder keys.

mx avatar Mar 04 '24 06:03 mx

I've developed a much simpler modification to halt the text encoder.

By leveraging the principle of underflow, I've implemented a mechanism where, upon reaching the designated stopping step for the text encoder, its learning rate is reduced to 1e-10.

It's important to note that this mechanism is specific ONLY to FP16 usage. Those utilizing BF16 or even FP32 should adjust this reduction to an even smaller number, like 1e-40, on lines 819 and 820.

It would probably be better to create a global variable for it, but w/e.

sangoi-exe avatar Mar 04 '24 06:03 sangoi-exe
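
The underflow claim is easy to check in isolation: 1e-10 is below the smallest fp16 subnormal (about 6e-8), so both the stored learning rate and any lr * grad product round to exactly zero in fp16, while bf16 and fp32 have a much wider exponent range and need a far smaller value, as noted above. A standalone check (illustrative, not part of the patch):

    import torch

    lr = torch.tensor(1e-10, dtype=torch.float16)
    grad = torch.tensor(0.5, dtype=torch.float16)
    print(lr)         # tensor(0., dtype=torch.float16): 1e-10 underflows to zero in fp16
    print(lr * grad)  # tensor(0., dtype=torch.float16): so the weight update is exactly zero

    # bf16 shares fp32's exponent range, so 1e-10 does NOT underflow there,
    # hence the suggestion to use something like 1e-40 instead for bf16/fp32.
    print(torch.tensor(1e-10, dtype=torch.bfloat16))  # prints a non-zero value close to 1e-10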

@DevArqSangoi If you do it that way, could active regularization effects still reduce the weights, depending on the optimizer in use? That would be a problem if they could.

mx avatar Mar 04 '24 07:03 mx

As far as I know, which isn't much, I think that max grad norm and scale weight norms (if that's what you're referring to) don't modify the weights directly; they modify the update, and if the learning rate is in underflow, the update will theoretically be 0.

sangoi-exe avatar Mar 04 '24 07:03 sangoi-exe

It shouldn't matter for Adam or AdamW or the optimizers based on them, either. Just not sure what weirdness is out there. I suppose it should really work fine for everything.

mx avatar Mar 04 '24 07:03 mx
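
To make the optimizer point concrete: in AdamW the decoupled weight decay is also multiplied by the learning rate, so with an effective LR of 0 neither the gradient step nor the decay term moves the weights. A self-contained check:

    import torch

    p = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.AdamW([p], lr=0.0, weight_decay=0.1)

    p.grad = torch.full_like(p, 0.5)
    opt.step()
    # Both the Adam update and the decoupled decay are scaled by lr, so p is unchanged:
    print(p.data)  # tensor([1., 1., 1.])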

I switched to OneTrainer and my trainings are so much better, and this is partially why. You really do need to stop the TE, not (as some have said to me) just ignorantly use a lower LR, due to how adaptive optimizers work.

DarkAlchy avatar Mar 04 '24 10:03 DarkAlchy

Btw, I think I used an old version of train_network.py in this PR 😂 too many lines changed, but the workaround is like 10~20 lines.

sangoi-exe avatar Mar 05 '24 03:03 sangoi-exe

I worked on the A1111 Dreambooth extension, and I'm pretty sure all we did to stop training the tenc there was to set text_encoder.train(False), which puts it into inference mode. text_encoder.eval() works too, IIRC. No more updates. That should work fine here, and it works for both Dreambooth and LoRA. Do it for both tencs if necessary.

The easiest way to implement this would be to put text_encoder.train($train_tenc) before the training loop, set it to true if the TE is being trained, and calculate $stop_tenc_step/epoch. At the end of each step (or epoch), if $step/epoch >= $stop_tenc_step/epoch, set $train_tenc to false.

saunderez avatar Mar 23 '24 00:03 saunderez
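
A sketch of that step-based toggle follows. Note it flips requires_grad_ in addition to .train(False), since eval mode changes dropout/normalisation behaviour but does not by itself stop gradient updates; all names here are placeholders, not existing sd-scripts variables:

    # Illustrative sketch of the toggle described above.
    def update_tenc_training(text_encoders, step, stop_tenc_step, currently_training):
        """Return the new training flag, freezing the encoders once the stop step is reached."""
        if currently_training and stop_tenc_step is not None and step >= stop_tenc_step:
            for te in text_encoders:
                te.requires_grad_(False)  # actually stops the parameter updates
                te.train(False)           # and drops to inference mode, as suggested above
            currently_training = False
        return currently_training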

I worked on the A1111 Dreambooth extension, and I'm pretty sure all we did to stop training the tenc there was to set text_encoder.train(False), which puts it into inference mode. text_encoder.eval() works too, IIRC. No more updates. That should work fine here, and it works for both Dreambooth and LoRA. Do it for both tencs if necessary.

The easiest way to implement this would be to put text_encoder.train($train_tenc) before the training loop, set it to true if the TE is being trained, and calculate $stop_tenc_step/epoch. At the end of each step (or epoch), if $step/epoch >= $stop_tenc_step/epoch, set $train_tenc to false.

Why such a brute force method? That is like taking a chisel to drive a nail into a boulder. It is really not a good idea to just turn it into eval mode when you really only have to skip over the TE training section once you pass the set point. In other words, just do nothing to the TEs.

DarkAlchy avatar Mar 23 '24 02:03 DarkAlchy

I worked on the A1111 Dreambooth extension, and I'm pretty sure all we did to stop training the tenc there was to set text_encoder.train(False), which puts it into inference mode. text_encoder.eval() works too, IIRC. No more updates. That should work fine here, and it works for both Dreambooth and LoRA. Do it for both tencs if necessary.

The easiest way to implement this would be to put text_encoder.train($train_tenc) before the training loop, set it to true if the TE is being trained, and calculate $stop_tenc_step/epoch. At the end of each step (or epoch), if $step/epoch >= $stop_tenc_step/epoch, set $train_tenc to false.

Why such a brute force method? That is like taking a chisel to drive a nail into a boulder. It is really not a good idea to just turn it into eval mode when you really only have to skip over the TE training section once you pass the set point. In other words, just do nothing to the TEs.

What text encoder training section? The text encoder isn't being trained independently; its parameters are part of the optimiser, which is stepped as a unit, so you're doing the same thing every step. If you told the optimiser to train the text encoder when you set it up and it has an LR that is not 0, it will train. Otherwise it won't. So telling the optimizer not to train it is exactly how it should be done, and in practice eval and train(False) achieve the same result.

I'm open to other suggestions, but I'm not seeing a whole lot of options. What method would you suggest? You could set the LR to 0 via the param group, but I don't think that's a particularly better solution.

saunderez avatar Mar 24 '24 02:03 saunderez

Well, look at how OneTrainer does it, as it does not just sledgehammer it when it hits the step/epoch/etc. the way this method does. I never looked at their code, but I can say you will not get the message about gradients no longer being tracked, so they must be taking a different approach than this.

DarkAlchy avatar Mar 24 '24 06:03 DarkAlchy

I checked OneTrainer's code and all they do is set requires_grad to false. I checked the current implementation of sd_dreambooth_webui and they do the same thing. You said this was brutish, so we're back to square one where there's no elegant solution.

OneTrainer, in StableDiffusionXLLoraSetup.py:

    if model.text_encoder_1_lora is not None:
        train_text_encoder_1 = config.text_encoder.train and \
            not self.stop_text_encoder_training_elapsed(config, model.train_progress)
        model.text_encoder_1_lora.requires_grad_(train_text_encoder_1)

    if model.text_encoder_2_lora is not None:
        train_text_encoder_2 = config.text_encoder_2.train and \
            not self.stop_text_encoder_2_training_elapsed(config, model.train_progress)
        model.text_encoder_2_lora.requires_grad_(train_text_encoder_2)

    if model.unet_lora is not None:
        train_unet = config.unet.train and \
            not self.stop_unet_training_elapsed(config, model.train_progress)
        model.unet_lora.requires_grad_(train_unet)

sd_dreambooth_webui, the relevant code in train_dreambooth.py:

    if args.use_lora:
        if not args.lora_use_buggy_requires_grad:
            set_lora_requires_grad(text_encoder, train_tenc)
            # We need to enable gradients on an input for gradient checkpointing to work
            # This will not be optimized because it is not a param to optimizer
            text_encoder.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
            if args.model_type == "SDXL":
                set_lora_requires_grad(text_encoder_two, train_tenc)
                text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
    else:
        text_encoder.requires_grad_(train_tenc)
        if args.model_type == "SDXL":
            text_encoder_two.requires_grad_(train_tenc)

which calls this function in lora.py

    def set_lora_requires_grad(model, requires_grad):
        for name, param in model.named_parameters():
            if "lora" in name:
                if param.requires_grad != requires_grad:
                    param.requires_grad = requires_grad

I maintain .train(False) isn't any worse; it's used if you aren't training the unet or text_encoder, and it doesn't throw any error messages if you do it mid-run. At the end of the day I don't care what the implementation is as long as it works, so please just get something committed so Kohya LoRA trainers aren't second-class citizens.

saunderez avatar Apr 02 '24 16:04 saunderez

It is a brutish way. I didn't say it wouldn't work, just that it is a sledgehammer for the lazy. No current "dev" wants to tackle the real way if there is an easy way out for them (the path of least resistance), which isn't the best way.

Edit: That last sentence sums it up for me as well. Why in the heck has this not been done in Kohya yet? I used OT, and I tell you what, stopping the unet or TE is a valuable tool in our tool belt that is sorely needed. Not sure why Kohya is still taking a dump on the toilet for this long.

DarkAlchy avatar Apr 03 '24 12:04 DarkAlchy

It is a brutish way. I didn't say it wouldn't work, just that it is a sledgehammer for the lazy. No current "dev" wants to tackle the real way if there is an easy way out for them (the path of least resistance), which isn't the best way.

Setting either train or requires_grad_ to False is the usual way to turn off the training of something in Torch. I have no idea what you mean by "brutish" or "sledgehammer", or why the implication that this is somehow bad; this is normal procedure.

mx avatar Apr 03 '24 15:04 mx

It is a brutish way. I didn't say it wouldn't work, just that it is a sledgehammer for the lazy. No current "dev" wants to tackle the real way if there is an easy way out for them (the path of least resistance), which isn't the best way.

Setting either train or requires_grad_ to False is the usual way to turn off the training of something in Torch. I have no idea what you mean by "brutish" or "sledgehammer", or why the implication that this is somehow bad; this is normal procedure.

Is this an elegant way? Come on, tell me. If this were an elegant way, then it would not warn you when it happens since you told it to do it. You just seem to want to argue for the sake of arguing, so go troll elsewhere.

Let's hope kohya finally implements it for unet and both text encoders.

DarkAlchy avatar Apr 03 '24 18:04 DarkAlchy

It is a brutish way. I didn't say it wouldn't work, just that it is a sledgehammer for the lazy. No current "dev" wants to tackle the real way if there is an easy way out for them (the path of least resistance), which isn't the best way.

Edit: That last sentence sums it up for me as well. Why in the heck has this not been done in Kohya yet? I used OT, and I tell you what, stopping the unet or TE is a valuable tool in our tool belt that is sorely needed. Not sure why Kohya is still taking a dump on the toilet for this long.

I'm not sure attacking and making demands of the devs who do this work for free is appropriate or helpful. You seem to be happy with OT, so you have an alternative.

Two separate people shared how this functionality is implemented elsewhere, including OT, which you said above works great for what you're doing. If the suggested method produces a warning but the results are as expected (as they are for you in OT), then that isn't much of a problem, similar to the "A matching Triton is not available" warning.

dkalintsev avatar Apr 03 '24 22:04 dkalintsev

It is a brutish way. I didn't say it wouldn't work, just that it is a sledgehammer for the lazy. No current "dev" wants to tackle the real way if there is an easy way out for them (the path of least resistance), which isn't the best way.

Edit: That last sentence sums it up for me as well. Why in the heck has this not been done in Kohya yet? I used OT, and I tell you what, stopping the unet or TE is a valuable tool in our tool belt that is sorely needed. Not sure why Kohya is still taking a dump on the toilet for this long.

I'm not sure attacking and making demands of the devs who do this work for free is appropriate or helpful. You seem to be happy with OT, so you have an alternative.

Two separate people shared how this functionality is implemented elsewhere, including OT, which you said above works great for what you're doing. If the suggested method produces a warning but the results are as expected (as they are for you in OT), then that isn't much of a problem, similar to the "A matching Triton is not available" warning.

Who in the hell said it was a problem? I said it was a brutish, non-elegant way of doing it, and that is all I said, but you may now go argue with the wall.

DarkAlchy avatar Apr 04 '24 07:04 DarkAlchy