
[BUG]: does Colossal-AI only support training all layers of a model?

Open · zhangvia opened this issue 1 year ago · 3 comments

🐛 Describe the bug

I'm using Colossal-AI to train Animate Anyone, and the training requires freezing some layers of the model. However, that leads to an assertion error in Colossal-AI with every plugin I've tried. So I'd like to ask: does Colossal-AI support training only part of the layers in a model?
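Roughly, what I'm doing looks like the sketch below (a toy model and GeminiPlugin stand in for the real Animate Anyone setup; the launch arguments depend on the Colossal-AI version):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # older releases require the config dict

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),   # stands in for the frozen backbone
    torch.nn.Linear(128, 10),    # stands in for the trainable layers
)
for p in model[0].parameters():  # freeze the first block only
    p.requires_grad_(False)

optimizer = HybridAdam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
booster = Booster(plugin=GeminiPlugin())
# in my runs, the assertion described above is raised when boosting the partially frozen model
model, optimizer, *_ = booster.boost(model, optimizer)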

Environment

none

zhangvia · Feb 19 '24 07:02

Hi zhangvia, could you please provide a more detailed description of this bug, e.g. which layers you want to freeze and why? You could also create a repo that reproduces it, and our tech team will help analyze the problem in their spare time.

If you are referring to lazy init, Colossal-AI does allow some layers to be initialized lazily and then used. If you are using GeminiPlugin, please see #5290 for more information.
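For reference, lazy init with GeminiPlugin usually looks roughly like the sketch below (the linear layer is just a placeholder model, and the exact launch/context arguments depend on your Colossal-AI version):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # older releases require the config dict

# parameters created inside the context are not materialized immediately;
# they are materialized and placed when the booster shards the model
with LazyInitContext(default_device=torch.device("cuda")):
    model = torch.nn.Linear(1024, 1024)  # placeholder for your real model

optimizer = HybridAdam(model.parameters(), lr=1e-4)
booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)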

Thank you very much for using Colossal-AI~

Yanjia0 · Feb 23 '24 09:02

I saw that issue. I'm asking specifically whether Colossal-AI supports training only part of the layers, precisely because of what I read there. I opened issue #5226, but no one answered it. When I use DeepSpeed to train Animate Anyone, it uses more memory per GPU than single-GPU training, and CPU offload doesn't help. So I wanted to use the Colossal-AI Gemini strategy to train Animate Anyone, but it seems Colossal-AI also can't handle this situation.

zhangvia · Feb 27 '24 02:02

You can use the TorchFSDP plugin (it can handle frozen layers), and you may also need this:

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
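A rough sketch of that combination is below. Note that build_model and the reference_net attribute are placeholders for your own Animate Anyone modules, and gradient_checkpointing_enable only exists on HF-style models:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch(config={})  # older releases require the config dict

model = build_model()                        # placeholder for your Animate Anyone model
for p in model.reference_net.parameters():  # placeholder name for the frozen sub-module
    p.requires_grad_(False)

# only valid if the module is an HF-style model exposing this method
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
booster = Booster(plugin=TorchFSDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)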

airlsyn · Mar 04 '24 06:03