Configure arbitrary frozen modules via config
I saw there is an issue about this (#306), but it hasn't been implemented yet, so I created this PR.
It allows you to specify a list of modules to be frozen via the config file or the command line, e.g. `--model.frozen_modules='tok_embeddings,layers.0.attention'`.
It also prints the number of frozen and trainable parameters during initialization.
Example configurations:
- No frozen modules (default).
- `[model] frozen_modules = 'tok_embeddings,tik_embeddings,layers.0.attention'`
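For context, a minimal sketch of what the freezing logic can look like; the helper name `freeze_modules` and the printed messages are illustrative, not necessarily what this PR implements:

```python
import torch.nn as nn

def freeze_modules(model: nn.Module, frozen_modules: str) -> None:
    """Freeze the sub-modules named in a comma-separated list,
    e.g. 'tok_embeddings,layers.0.attention'."""
    module_map = dict(model.named_modules())
    for name in (n.strip() for n in frozen_modules.split(",") if n.strip()):
        if name not in module_map:
            # Surface invalid module names early instead of silently ignoring them.
            raise ValueError(f"frozen module '{name}' not found in model")
        for param in module_map[name].parameters():
            param.requires_grad = False

    # Report how much of the model is frozen vs. trainable.
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"frozen params: {frozen:,} | trainable params: {trainable:,}")
```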
Hope this helps.
I'm not sure if this should be a configurable option. Instead, if a model requires some parts to be frozen, it should be coded in the model. And our trainer should be able to support different use cases, including ones where part of the parameters is frozen, and load the checkpoints correctly.
Whether or not this should be configurable via toml depends on the use case. As @fegin pointed out, in most use cases we've heard of, e.g. freezing certain parts (such as the image encoder) of a multimodal/diffusion model, it is a static decision and the model code should handle it. It should only be configurable if the set of trainable parameters would change/shift during training. Can you give some examples of such cases?
We will need this mechanism down the road anyway, but probably first as a util function (called from model code), similar to this function in torchtune: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/peft/_utils.py#L65
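For illustration, such a util could look roughly like the sketch below; the function name and the keep-list convention are assumptions inspired by the linked torchtune util, not its exact API:

```python
import torch.nn as nn

def set_trainable_params(model: nn.Module, trainable_names: set[str]) -> None:
    """Mark exactly the named parameters as trainable and freeze the rest.

    Intended to be called from model code, so the freezing policy lives
    next to the model definition rather than in the config."""
    for name, param in model.named_parameters():
        param.requires_grad = name in trainable_names
```

Model code would then call it after construction, e.g. `set_trainable_params(model, {n for n, _ in model.named_parameters() if n.startswith('projector.')})` to train only a (hypothetical) projector.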
I understand if you'd rather have this in the model code, but my use cases are a bit more dynamic, so I implemented it in a generic manner. This way there is one less thing to duplicate when we switch between different models. Even when training the same multimodal model, there are usually multiple stages of training on the same modeling code, for example these actual scenarios (see the config sketch after this list):
- Freeze both encoder and decoder; train only the projector/connector.
- Freeze everything except the patch embedding, to adapt to a new data domain.
- Freeze the decoder; finetune the encoder + projector, etc.
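With a configurable option, each stage is just a different override and no model-code edit is needed; using hypothetical module names for such a multimodal model:
- Projector alignment stage: `--model.frozen_modules='encoder,decoder'`
- Decoder-frozen finetuning stage: `--model.frozen_modules='decoder'`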
> including ones where part of the parameters is frozen, and load the checkpoints correctly.
This case is indeed cleaner to handle in the model code.