No clear way to load models
🚀 Feature Request
Loading models is a bit of a pain right now. It's done differently across multiple scripts (including our internal eval scripts), and not every approach is compatible with every checkpoint format.
This typically requires setting a TON of command line args based on what the model checkpoints need (--model-parallel, --ddp-backend fully_sharded, --distributed-port, etc.). Many of these args could be inferred just by inspecting the checkpoint files.
Afterwards, we should refactor a few scripts to use this One True Method.
Is there any way you could share the different eval scripts? :-)
Is this related to https://github.com/facebookresearch/metaseq/issues/73 ?
I cannot find metaseq-api-local.py anywhere in OPT/.
From #277
We should make model loading "just work". I shouldn't need to pass so many args to get it to find the right checkpoint. I should be able to specify sharded checkpoints by pointing to the shard0-rank0 .pt file.
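As a rough illustration of what that could look like (load_model and its signature below are hypothetical, not an existing metaseq API):

```python
# Hypothetical sketch of the desired single entry point; load_model and its
# signature illustrate the goal of this issue, they are not existing metaseq code.
def load_model(checkpoint_path: str):
    """
    Load a model from any of the checkpoint layouts described below
    (singleton, unsharded model parallel, sharded model parallel),
    inferring model_parallel_size and checkpoint_shard_count from the
    checkpoint files instead of requiring command line args.
    """
    ...

# Desired usage, regardless of checkpoint layout:
# model = load_model("/path/to/checkpoints/reshard-model_part-0-shard0.pt")
```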
Types of model checkpoints
We currently have three types of model checkpoints:
1. Singleton checkpoint - for example, the 355M checkpoint. The file name looks like reshard.pt.
2. Unsharded model parallel checkpoint - the file names look like reshard-model_part-*.pt, where * goes from 0 to number_of_model_parts - 1.
3. Sharded model parallel checkpoint - the file names look like reshard-model_part-0-shard0.pt, where the model part and shard numbers range over the number of model parallel parts and fully sharded data parallel shards, respectively.
Here, the name "reshard" is just a convention; the prefix can be any name, for example "125m-model_part-0-shard0.pt".
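Given these naming conventions, the layout could in principle be inferred from the files alone. A rough sketch (the helper name and glob patterns are assumptions based on the conventions above, not existing metaseq code):

```python
# Rough sketch: infer the checkpoint layout from the file names alone.
# The helper name and glob patterns are assumptions based on the naming
# conventions described above, not existing metaseq code.
import glob
import os
import re

def infer_checkpoint_layout(checkpoint_path: str) -> str:
    """Classify a checkpoint as singleton, unsharded model parallel, or sharded model parallel."""
    directory, fname = os.path.split(checkpoint_path)
    # Strip the part/shard suffixes to recover the prefix, e.g. "reshard" or "125m".
    prefix = re.sub(r"(-model_part-\d+)?(-shard\d+)?\.pt$", "", fname)

    if glob.glob(os.path.join(directory, f"{prefix}-model_part-*-shard*.pt")):
        return "sharded model parallel"      # e.g. reshard-model_part-0-shard0.pt
    if glob.glob(os.path.join(directory, f"{prefix}-model_part-*.pt")):
        return "unsharded model parallel"    # e.g. reshard-model_part-0.pt
    return "singleton"                       # e.g. reshard.pt
```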
How do we determine the type of model checkpoint?
cfg.common.model_parallel_size - determines the model parallel size. If this is 1, the model is not model parallel, though it may still be sharded via FSDP.
cfg.checkpoint.checkpoint_shard_count - determines the number of FSDP shards for the model. For model parallel models, each model part has this many shards.
If both of these parameters are 1, we have a singleton checkpoint.
Both these config values can be determined from the model checkpoint itself.
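For example, one could read both values straight from the checkpoint. A minimal sketch, assuming the checkpoint stores its training config under a top-level "cfg" key (as metaseq checkpoints typically do); the exact key layout is an assumption:

```python
# Minimal sketch: read model_parallel_size and checkpoint_shard_count from the
# checkpoint itself instead of command line args. The "cfg" layout below is an
# assumption about how the checkpoint was saved.
import torch

def infer_parallelism(checkpoint_path: str):
    state = torch.load(checkpoint_path, map_location="cpu")
    cfg = state["cfg"]
    model_parallel_size = cfg["common"]["model_parallel_size"]
    shard_count = cfg["checkpoint"]["checkpoint_shard_count"]

    if model_parallel_size == 1 and shard_count == 1:
        layout = "singleton"
    elif shard_count == 1:
        layout = "unsharded model parallel"
    else:
        layout = "sharded (FSDP), possibly model parallel"
    return model_parallel_size, shard_count, layout
```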
cfg.distributed_training.use_sharded_state - if True, then state_dict will return FSDP.local_state_dict and load_state_dict will call FSDP.load_local_state_dict. Otherwise, state_dict will return the full model weights on data parallel rank 0 (empty on other ranks) and load_state_dict will broadcast model weights from rank 0 to other ranks.
From metaseq/distributed/fully_sharded_data_parallel.py
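In terms of the file layouts above, this roughly means: with use_sharded_state each FSDP rank deals only with its own shard file, while without it the consolidated weights live on data parallel rank 0 and are broadcast to the other ranks. A hedged sketch (the path pattern and the assumption that the shard index matches the data parallel rank are illustrative, not metaseq's actual loading code):

```python
# Hedged sketch mapping use_sharded_state onto the file naming convention
# described earlier. Assumes the shard index corresponds to the data parallel
# rank; this is an illustration, not existing metaseq code.
import os

def checkpoint_file_for_rank(directory: str, prefix: str, model_part: int,
                             dp_rank: int, use_sharded_state: bool) -> str:
    if use_sharded_state:
        # Each rank loads its own FSDP shard, e.g. reshard-model_part-0-shard3.pt
        return os.path.join(directory, f"{prefix}-model_part-{model_part}-shard{dp_rank}.pt")
    # Consolidated weights: one file per model part, loaded on rank 0 and broadcast.
    return os.path.join(directory, f"{prefix}-model_part-{model_part}.pt")
```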