
[BUG]: text_encoder was not finetuned with script examples/images/dreambooth/train_dreambooth_colossalai.py

Open · SummerTrains opened this issue 2 years ago · 10 comments

🐛 Describe the bug

When I train dreambooth with the script examples/images/dreambooth/train_dreambooth_colossalai.py, I found that the text_encoder model file text_encoder/pytorch_model.bin has not changed. This means the text encoder has not been optimized.

Then I found that only the unet has been wrapped with GeminiAdamOptimizer (https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/dreambooth/train_dreambooth_colossalai.py#L475), but the text encoder has not.
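
For reference, the wrapping in the script looks roughly like this (a simplified sketch; the exact arguments used in the repo may differ):

    # only the unet is handed to the Gemini/ZeRO optimizer wrapper;
    # the text encoder never reaches any optimizer
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)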

So how can I finetune both the unet and the text encoder?

Environment

No response

SummerTrains avatar Mar 16 '23 09:03 SummerTrains

Hi, currently the setting for the text encoder is requires_grad=False, as shown in this line. Can you try making it trainable and later saving the model states accordingly?
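
For example, a rough sketch of the change (the variable names follow the script; the saving call at the end is just one possible way to persist the weights, the script may collect states differently):

    # unfreeze the text encoder instead of keeping requires_grad=False
    text_encoder.requires_grad_(True)
    text_encoder.train()

    # ... after training, remember to also save its weights
    text_encoder.save_pretrained(f"{args.output_dir}/text_encoder")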

JThh avatar Mar 17 '23 03:03 JThh

@JThh Thanks! I have set text_encoder.requires_grad_(True) and text_encoder.train(), but the parameters of the text encoder were still not changed. I printed text_encoder.text_model.encoder.layers[3].self_attn.q_proj.weight, and it has not changed!
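
For reference, this is roughly how I checked it (a quick before/after comparison around a few training steps):

    # snapshot one weight tensor before training
    before = text_encoder.text_model.encoder.layers[3].self_attn.q_proj.weight.detach().clone()

    # ... run a few training steps ...

    after = text_encoder.text_model.encoder.layers[3].self_attn.q_proj.weight.detach()
    print(torch.allclose(before, after))  # prints True, i.e. the weight never changed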

SummerTrains avatar Mar 20 '23 07:03 SummerTrains

@JThh And I think the reason is that the text encoder has not been wrapped with GeminiAdamOptimizer or any other optimizer.

SummerTrains avatar Mar 20 '23 07:03 SummerTrains

Yes, @SummerTrains , you are right. Can you try imitating what was done to this file?

JThh avatar Mar 21 '23 12:03 JThh

@JThh The most important code there is params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()), which is then passed into a torch optimizer. But that is different from GeminiAdamOptimizer.

The key point is that ZeroOptimizer, the superclass of GeminiAdamOptimizer, takes a single module (ZeroDDP) argument, not a list of modules (ZeroDDP).

So how can I wrap the parameters of both text_encoder and unet into one GeminiAdamOptimizer?
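
For reference, the plain-PyTorch pattern I am referring to from the diffusers script looks roughly like this (simplified):

    # diffusers-style: one torch optimizer over the chained parameters of both models
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters()
    )
    optimizer = torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)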

SummerTrains avatar Mar 22 '23 04:03 SummerTrains

Would you please change this __init__ method to something like:

    def __init__(self, models: Union[List[torch.nn.Module], torch.nn.Module], **defaults: Any) -> None:
        # accept either a single module or a list of modules and chain their parameters
        if isinstance(models, list):
            parameters = itertools.chain(*[module.parameters() for module in models])
        else:
            parameters = models.parameters()
        optimizer = HybridAdam(parameters, **defaults)
        super().__init__(optimizer, models, **defaults)

and import itertools as well as List and Union from typing accordingly?

I will submit a change to this code file formally soon.

JThh avatar Mar 22 '23 07:03 JThh

It doesn't work, because ZeroOptimizer takes module: ZeroDDP as input. I also tried changing it to super().__init__(optimizer, models[0], **defaults), but that raises another error.

So I'm looking forward to your MR as soon as possible.

SummerTrains avatar Mar 22 '23 09:03 SummerTrains

@JThh Hello, how is it going? When do you expect to fix this problem?

SummerTrains avatar Mar 29 '23 06:03 SummerTrains

I am getting to this now.

I think it would require some refactoring of the code. Can we take another route (such as ZeRO stage 1)?

An example:

>>> my_optim = HybridAdam(itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters(), lr=1e-3)
>>> zero_unet = zero_model_wrapper(unet, zero_stage=1)
>>> zero_encoder = zero_model_wrapper(text_encoder, zero_stage=1)
>>> zero_optim = zero_optim_wrapper(zero_unet, my_optim)  # zero_unet can be ignored here: the module argument is not used in ZeRO stage 1

Can you test this out?
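
Sketched end to end in the dreambooth script, that route might look roughly like the following (the exact import paths and keyword arguments may differ between ColossalAI versions, and compute_loss is just a stand-in for the existing dreambooth loss code):

    import itertools

    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

    # ZeRO stage 1 shards optimizer states, so the modules themselves stay (mostly) untouched
    unet = zero_model_wrapper(unet, zero_stage=1)
    text_encoder = zero_model_wrapper(text_encoder, zero_stage=1)

    # one optimizer over the chained parameters of both models
    params = itertools.chain(unet.parameters(), text_encoder.parameters())
    optimizer = zero_optim_wrapper(unet, HybridAdam(params, lr=args.learning_rate))

    # training step: the ZeRO optimizer owns the backward pass
    loss = compute_loss(unet, text_encoder, batch)  # placeholder for the dreambooth loss
    optimizer.backward(loss)
    optimizer.step()
    optimizer.zero_grad()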

JThh avatar Apr 06 '23 17:04 JThh

Hi @JThh, thanks for your reply. Have you solved this issue?

hdjsjyl avatar Nov 22 '23 10:11 hdjsjyl