fms-fsdp
A write-up on Meta Device Init x Pretraining
Scope
This write-up only applies to the initial model init. For cases that require loading a checkpoint (continue-pretraining, fine-tuning and inference), none of this is needed, as any init would be overwritten by the ckpt. Therefore, this mainly targets "pretraining from scratch".
Background on meta device init
There are two ways to leverage meta device to save cpu memory during model init:
- create a full model copy on rank0 while putting the model on meta device on all other ranks.
- put the model on meta device on all ranks, including rank0.
The first method inits a full model on rank0 and utilizes sync_module_states=True during the FSDP call to broadcast the model from rank0 to all other ranks. This reduces cpu memory from world_size total copies to a single copy.
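A minimal sketch of method 1 (assuming the process group is already initialized; build_model() is a hypothetical stand-in for the actual FMS model constructor):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# rank0 holds the single full cpu copy; every other rank stays on the meta device
if dist.get_rank() == 0:
    model = build_model()              # full init on cpu (the 1 copy)
else:
    with torch.device("meta"):
        model = build_model()          # no memory allocated

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,           # broadcast rank0's weights to all other ranks
    # non-rank0 ranks only need their meta params materialized before the broadcast
    param_init_fn=lambda m: m.to_empty(device=torch.cuda.current_device(), recurse=False),
)
```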
The second method puts the model on meta device on all ranks (including rank0) and utilizes a proper param_init_fn and/or post-FSDP init.
Compared to the first method, the second not only saves cpu memory (0 copies), but also greatly reduces model init time, as it avoids initializing a full model on cpu (for large models like 70b, this could take 20 minutes).
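Method 2's construction step is then just (again using the hypothetical build_model()):

```python
# every rank, including rank0, builds the model on the meta device: no cpu memory
# is allocated and "construction" is near-instant; real weights are materialized
# later, during the FSDP call, via param_init_fn
with torch.device("meta"):
    model = build_model()
```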
Method 2 is both faster and saves more cpu memory; however, it can be very tricky to set up properly for pretraining and can cause silent problems. Unlike continue-pretraining/fine-tuning/inference, where model init doesn't matter because it is overwritten by the loaded ckpt, pretraining requires a proper model init, which is crucial. And model init for method 2 can be tricky no matter which stage you apply the init at:
pre-FSDP init
This isn't possible with method 2, as all ranks are on meta device before the FSDP call. This is also the reason method 1 is much safer: you do everything you want before the FSDP call, while the model is still a full copy sitting on cpu; you can perform any init you need and it will be properly broadcast to the other ranks during the FSDP call. But again, we want method 2 and we don't want any cpu copy, so we will pass on this.
during-FSDP init
This is achieved by leveraging param_init_fn, which is applied to the "to be materialized" modules. Since we need to materialize each module and put it on device first (the full model is on meta device), such a param_init_fn is typically something like:
```python
import torch

def param_init_fn(module):
    # materialize the meta-device module on the target device, then run the real init
    module.to_empty(device=torch.cuda.current_device())
    module.reset_parameters()
```
Here comes the tricky part, where we might get silent problems. param_init_fn is applied to all to-be-materialized modules, which are popped/dequeued in a top-down, parent-then-children fashion (reference). Although this is already a great improvement over the old times when we started this work (this has very detailed explanations of some old issues which we also observed and had to overcome), the current design still relies on an implicit contract: "param_init_fn should only initialize a module's own parameters/buffers, not those of any sub-module". Another implicit requirement is that such an "init" must be defined on every possible module. So what happens if we don't strictly follow these rules, which is the situation in FMS today?
Sub-modules would be re-initialized multiple times. Our reset_parameters() is designed so that calling model.reset_parameters() inits the full model with the true/desired init. Similarly, LLaMABlock.reset_parameters() inits the full block. This is desirable, as typically we want this single-line, model-wise init, and it works well for method 1. But imagine what happens if we use it as param_init_fn: recall that the "to be materialized" modules will be something like [LLaMABlock, MultiHeadAttention, Linear, Linear, Linear, etc.], so child modules like Linear will be re-initialized multiple times, and this is problematic:
- issues discussed in the reference shared above.
- more importantly: silent problems if our init does not provide FULL coverage (see the toy illustration after this list). Again, recall that we defined our "init" at the model level (llama.reset_parameters()) and at the "key module" level (attn_block, embed, mlp_block, layer_norm), as that was typically sufficient. But these inits get "silently" overwritten by lower-level modules (e.g. Linear), because basic modules like Linear have their own implementation of reset_parameters(). So during this "re-init" of the leaf modules, the wrong init overwrites our true init, and it is silent!
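To make this failure mode concrete, here is a toy, self-contained illustration (TinyBlock is a made-up stand-in, not an actual FMS module):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Toy stand-in for an FMS block with a recursive reset_parameters()."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def reset_parameters(self):
        # the "true" init we want for the whole block, including its children
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

block = TinyBlock()
block.reset_parameters()
print(block.proj.weight.std())   # ~0.02, the desired init

# during-FSDP materialization also visits the child Linear, whose own default
# reset_parameters() (kaiming uniform) then runs and silently clobbers the above
block.proj.reset_parameters()
print(block.proj.weight.std())   # no longer ~0.02, and no error is raised
```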
post-FSDP init
This can be even trickier. It is less preferred than using param_init_fn, so I won't go into too much detail. But doing post-FSDP init involves manipulating model params outside forward/backward, where you will run into issues like "illegal memory access" because the model is already sharded. You could technically leverage FSDP.summon_full_params() with some "writebacks" to achieve part of this, but that is less efficient and less user-friendly than leveraging param_init_fn. So this is also not wanted.
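For completeness, a rough sketch of what such a post-FSDP init could look like (not recommended, for the reasons above):

```python
# gather the full (unsharded) params, mutate them under no_grad, and rely on
# writeback=True to push the changes back into the local shards
with FSDP.summon_full_params(model, writeback=True), torch.no_grad():
    for module in model.modules():
        if isinstance(module, LayerNormParameterized):   # e.g. re-init one module type
            module.reset_parameters()
```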
what to do with FMS
So it seems "during-FSDP init with param_init_fn" is the way to go, but we would have to meet the contract:
- rewrite ALL init (reset_parameters) to be non-recursive.
- provide FULL coverage for init.
Is there a way to avoid doing so, and potentially re-use our existing recursive version? Well, the answer is yes, and the trick turns out to be simple: we just need to add a "filter" so that param_init_fn only applies the (recursive) init to modules that are mutually exclusive yet cover 100% of the params. This way, no re-init would ever happen.
```python
def param_init_fn(module):
    if (
        # the modules below are mutually exclusive but together cover 100% of the model params
        isinstance(module, MultiHeadAttention)
        or isinstance(module, WordEmbedding)
        or isinstance(module, GatedLinearUnit)
        or isinstance(module, LayerNormParameterized)
    ):
        module.to_empty(device=torch.cuda.current_device())
        with torch.no_grad():
            module.reset_parameters()
```
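Wiring it up then looks roughly like this (build_model() again standing in for the real constructor): every rank builds the model on the meta device, and the FSDP call materializes and initializes each wrapped unit through the filtered param_init_fn above.

```python
with torch.device("meta"):
    model = build_model()

model = FSDP(
    model,
    param_init_fn=param_init_fn,
    device_id=torch.cuda.current_device(),
)
```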