metaseq
fix: ensure last checkpoint is always saved, refactor training stop conditions to be computed in a single location
Issues
1 Inconsistent checkpoint filenames saved by the trainer
In our pipeline we often have a sequence of steps such as (train, reshard/unflatten, evaluate). The output files of training become the inputs to the resharding scripts. For the pipeline to work reliably, the output files need consistent filenames, such as checkpoint_last-model_part-0-shard0.pt.
When running metaseq.cli.train with tasks such as streaming_finetune_language_modeling, there are two different stopping conditions, set by --max-epochs and --max-updates. Whichever limit is hit first causes training to stop.
The issue is that the checkpoint_last-* file is written ONLY when neither the epoch condition nor the update condition is true. This couples the checkpoint filename to the stopping conditions.
Notice that checkpoints[0] selects only the FIRST filename whose condition is true:
https://github.com/facebookresearch/metaseq/blob/c16d21047d975b7a925648f38cba3190a8ef27d6/metaseq/checkpoint_utils.py#L89-L99
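To illustrate the "first true condition wins" behavior, here is a simplified sketch of the selection pattern described above. It is not copied from checkpoint_utils.py: the function name select_checkpoint, its parameters, and the filename templates are illustrative assumptions, and the real code has additional conditions.

```python
import collections
import os


def select_checkpoint(save_dir: str, epoch: int, updates: int, end_of_epoch: bool,
                      save_interval_updates: int, suffix: str = "") -> str:
    """Hypothetical, simplified model of picking checkpoints[0]."""
    # Insertion order matters: the first filename whose condition is true wins.
    conds = collections.OrderedDict()
    conds[f"checkpoint{epoch}{suffix}.pt"] = end_of_epoch
    conds[f"checkpoint_{epoch}_{updates}{suffix}.pt"] = (
        not end_of_epoch
        and save_interval_updates > 0
        and updates % save_interval_updates == 0
    )
    conds[f"checkpoint_last{suffix}.pt"] = True  # assume last checkpoints are enabled
    checkpoints = [os.path.join(save_dir, fn) for fn, cond in conds.items() if cond]
    # Only checkpoints[0] is written, so checkpoint_last is skipped whenever
    # one of the earlier conditions is also true.
    return checkpoints[0]
```

With save_interval_updates=500, a call at the end of epoch 2 returns ckpt/checkpoint2.pt, and a mid-epoch call at update 1000 returns ckpt/checkpoint_2_1000.pt; only at an off-interval update such as 1001 does it return ckpt/checkpoint_last.pt.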
Goal
We want to be able to run the jobs/pipeline and change the stopping conditions without implicitly changing the output filename that is passed to subsequent commands/scripts.
2 Training stop was handled in multiple locations
Loop condition: https://github.com/facebookresearch/metaseq/blob/c16d21047d975b7a925648f38cba3190a8ef27d6/metaseq/cli/train.py#L209
Loop break: https://github.com/facebookresearch/metaseq/blob/c16d21047d975b7a925648f38cba3190a8ef27d6/metaseq/cli/train.py#L212-L213
This makes it harder to reason about which condition will cause training to stop.
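As a rough illustration of computing the stop decision in one place, the toy loop below has no loop condition and no second break site: a single should_stop check decides when training ends and reports which limit fired. StopConfig, should_stop, and train are hypothetical names, not metaseq's train.py. The counters are incremented before the check and compared with >=, which is a simplifying assumption; the exact operators in metaseq depend on where its own counters are updated.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class StopConfig:
    # Hypothetical stand-ins for the --max-epochs / --max-updates limits; 0 disables a limit.
    max_epoch: int = 0
    max_update: int = 0


def should_stop(cfg: StopConfig, epoch: int, num_updates: int,
                end_of_epoch: bool) -> Optional[str]:
    """Single place that decides whether training stops; returns the reason or None."""
    if cfg.max_update > 0 and num_updates >= cfg.max_update:
        return "max_update"
    if cfg.max_epoch > 0 and end_of_epoch and epoch >= cfg.max_epoch:
        return "max_epoch"
    return None


def train(cfg: StopConfig, updates_per_epoch: int) -> Tuple[int, int, str]:
    """Toy loop whose only exit is the should_stop check after each update."""
    if cfg.max_epoch <= 0 and cfg.max_update <= 0:
        raise ValueError("at least one stop limit must be set")
    if updates_per_epoch < 1:
        raise ValueError("updates_per_epoch must be positive")
    epoch, num_updates = 1, 0
    while True:
        for step in range(1, updates_per_epoch + 1):
            num_updates += 1  # counter is incremented before the check
            reason = should_stop(cfg, epoch, num_updates, step == updates_per_epoch)
            if reason is not None:
                return epoch, num_updates, reason
        epoch += 1
```

With 5 updates per epoch, StopConfig(max_epoch=2, max_update=7) stops in epoch 2 at update 7 because of max_update, while StopConfig(max_epoch=2, max_update=100) stops at the end of epoch 2 (update 10) because of max_epoch. Either way, the reason comes from one function.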
Solution
- Consolidate all training stop conditions into validate_and_save and should_stop
- Use > instead of >= in the stop conditions
- Always save the checkpoint_last* file
  - This means that in some cases multiple checkpoints will be saved:
    - Epoch AND last checkpoint
    - Updates AND last checkpoint
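A minimal sketch of the proposed saving rule: instead of keeping only the first matching filename, collect every matching one and always append the last checkpoint. The function name checkpoints_to_write, its parameters, and the filename templates are illustrative assumptions; the sketch only lists the paths and does not say whether the extra file is written by a second save or by copying.

```python
import os
from typing import List


def checkpoints_to_write(save_dir: str, epoch: int, updates: int, end_of_epoch: bool,
                         save_interval_updates: int, suffix: str = "") -> List[str]:
    """Hypothetical: return every checkpoint path to write for this step."""
    names = []
    if end_of_epoch:
        names.append(f"checkpoint{epoch}{suffix}.pt")  # epoch checkpoint
    elif save_interval_updates > 0 and updates % save_interval_updates == 0:
        names.append(f"checkpoint_{epoch}_{updates}{suffix}.pt")  # update checkpoint
    # Always included, so downstream steps can rely on a stable filename
    # regardless of which stop condition ended training.
    names.append(f"checkpoint_last{suffix}.pt")
    return [os.path.join(save_dir, n) for n in names]
```

At the end of epoch 2 this yields both checkpoint2.pt and checkpoint_last.pt (the "Epoch AND last checkpoint" case); at update 1000 with save_interval_updates=500 it yields checkpoint_2_1000.pt and checkpoint_last.pt (the "Updates AND last checkpoint" case).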
Testing
I wasn't able to test this, since this PR targets metaseq main rather than our fork's main; I wanted to at least share the ideas. Because changing training stop conditions is a significant change, it would be good if someone could submit two small jobs to test it: one limited by max-epochs and the other by max-updates, confirming that both save checkpoint_last files.
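The final check for those two test jobs could be as small as the sketch below, assuming each job writes into its own save directory. The helper name missing_last_checkpoints, the job labels, and the directory layout are hypothetical; it only looks for files matching checkpoint_last*.pt.

```python
import glob
import os
from typing import Dict, List


def missing_last_checkpoints(save_dirs: Dict[str, str]) -> List[str]:
    """Return the labels of jobs whose save directory has no checkpoint_last*.pt file."""
    missing = []
    for job, save_dir in save_dirs.items():
        if not glob.glob(os.path.join(save_dir, "checkpoint_last*.pt")):
            missing.append(job)
    return missing
```

For example, missing_last_checkpoints({"max-epochs": epochs_dir, "max-updates": updates_dir}) should return an empty list once the change works for both stopping conditions.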
Related to #726