torchtune
Implement step based checkpointing
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
Closes #2105. This is a widely requested feature that allows users to have greater control over checkpointing frequency in torchtune.
TODO: Add commentary on design decisions. Acknowledge spaghetti code. Beg forgiveness.
Changelog
- Update `FullModelHFCheckpointer` to accept a step parameter when saving a checkpoint. Use that step to designate the checkpoint folder name. Keep `epoch_{}` as a fall-back for BC.
- Modify the `full_finetune_single_device.py` recipe to utilize step-based checkpointing.
- Add tests for the `full_finetune_single_device.py` recipe w/ step-based checkpointing.
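A minimal sketch of how the new save path might be exercised from a recipe. The `step` kwarg and the `step_{N}` folder naming are assumptions based on the changelog above, not the merged API; the paths and surrounding loop variables are placeholders.

```python
# Sketch only (not the merged API): saving an intermediate checkpoint keyed by
# global step. Assumes FullModelHFCheckpointer.save_checkpoint gains an optional
# `step` argument that writes to ${output_dir}/step_{step}; with no step it
# falls back to the legacy ${output_dir}/epoch_{epoch} layout.
from torchtune.training import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-3.2-1B-Instruct",   # hypothetical paths
    checkpoint_files=["model.safetensors"],
    model_type="LLAMA3_2",
    output_dir="/tmp/torchtune/llama3_2_1B/full_single_device",
)

# Inside the training loop (model, curr_epoch, global_step come from the recipe),
# every N optimizer steps:
checkpointer.save_checkpoint(
    state_dict={"model": model.state_dict()},  # plus recipe state for mid-epoch resume
    epoch=curr_epoch,
    step=global_step,            # assumed kwarg name -> selects step_{global_step}/
    intermediate_checkpoint=True,
)
```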
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [x] add unit tests for any new functionality
- [x] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [x] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
Evidence of correct number of checkpoints being saved
```
(joe-torchtune) [[email protected] ~/projects/joe-torchtune (impl-step-based-ckpt)]$ ls /tmp/torchtune/llama3_2_1B/full_single_device/
step_100  step_125  step_150  step_175  step_200  step_25  step_50  step_75  torchtune_config.yaml
```
Evidence of correct resuming from ckpt mid-epoch
Evidence of correct resuming from ckpt at epoch boundary
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example
- [ ] I did not change any public API
- [x] I have added an example to docs or docstrings
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2384
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:heavy_exclamation_mark: 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
:x: 2 New Failures, 2 Unrelated Failures
As of commit 650d91d9208a9e9e2fb47ecd483d08ffdf7d9528 with merge base 3d735916bd9efca600f79fe8d77a757c1160279a:
NEW FAILURES - The following jobs have failed:
- GPU tests / gpu_test (3.11, stable) (gh)
  tests/recipes/test_qat_lora_finetune_distributed.py::TestQATLoRAFinetuneDistributedRecipe::test_training_state_on_resume_with_async_checkpointing[llama3/8B_qat_lora-llama3-tune-False]
- Lint / lint (3.10) (gh)
  Process completed with exit code 1.
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
- GPU tests / gpu_test (3.10, stable) (gh) (trunk failure)
  ##[error]The operation was canceled.
- GPU tests / gpu_test (3.9, stable) (gh) (trunk failure)
  ##[error]The operation was canceled.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
- [x] `recipe_state` is still saved to `${output_dir}`, not `${output_dir}/step_XXX`
- [x] `resume_from_checkpoint` logic should be updated
  - RN it looks for the `${output_dir}` checkpoint, not `step_XXX`
  - maybe replace top-level `cfg.resume_from_checkpoint` with `cfg.checkpointer.resume_from`, which is either "latest" (default) or the path to the checkpoint to resume from. Or separate, mutually exclusive `resume_from: /path/` and `resume_from_latest: True` (see the sketch after this list)
  - offtopic, but `cfg.resume_from_checkpoint` is mentioned in code as deprecated and replaced by `should_load_recipe_state`, but de facto `resume_from_checkpoint` is mandatory and `should_load_recipe_state` doesn't work
- [x] `recipe_state` has the proper step and epoch to continue from, but the train cycle still starts from 0 -> logs start from 0 & checkpointing starts from 0
- [x] lr schedulers aren't synced with the resume step
- maybe save the wandb run?..... 🥺
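A minimal sketch of how a `resume_from: "latest"` option could resolve the newest `step_XXX` directory under `output_dir`. The config key and helper name are hypothetical, taken from the suggestion in the checklist above, not from the PR itself.

```python
# Hypothetical helper: pick the step_{N} checkpoint directory with the largest N.
# Assumes the step-based layout from this PR (${output_dir}/step_{N}/...);
# the "latest" semantics are only a proposed design from the review checklist.
import re
from pathlib import Path
from typing import Optional


def find_latest_step_dir(output_dir: str) -> Optional[Path]:
    """Return the step_{N} subdirectory with the largest N, or None if absent."""
    pattern = re.compile(r"^step_(\d+)$")
    candidates = []
    for child in Path(output_dir).iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match:
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    return max(candidates, key=lambda pair: pair[0])[1]


# e.g. resolving cfg.checkpointer.resume_from == "latest" (proposed, not actual API):
latest = find_latest_step_dir("/tmp/torchtune/llama3_2_1B/full_single_device")
# -> .../step_200 for the directory listing shown in the test plan above
```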
Codecov Report
Attention: Patch coverage is 26.06061% with 244 lines in your changes missing coverage. Please review.
Project coverage is 59.86%. Comparing base (`3134f90`) to head (`ce41c15`). Report is 1 commit behind head on main.
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #2384      +/-   ##
==========================================
- Coverage   60.14%   59.86%    -0.29%
==========================================
  Files         437      437
  Lines       26912    27028      +116
==========================================
- Hits        16187    16181        -6
- Misses      10725    10847      +122
```
:umbrella: View full report in Codecov by Sentry.
rebased to #2869