torchtune
Support Optimizer-in-the-backward
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
Enable Optimizer-in-the-backward for full_finetune_distributed
Changelog
- Update full_finetune_distributed to enable Optimizer-in-the-backward (a minimal sketch of the idea follows this list)
- Update test_full_finetune_distributed with an _optimizer_in_bwd config
- Update test_distributed to test running with/without optimizer_in_the_backward, and to verify behavior after saving and loading the state_dict
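For context, here is a minimal sketch of the optimizer-in-the-backward pattern (not the exact recipe code; the helper name and learning rate are illustrative): each trainable parameter gets its own optimizer, which is stepped from a post-accumulate-grad hook so the gradient can be zeroed and freed as soon as it is produced during the backward pass.
```python
import torch

def setup_optimizer_in_bwd(model: torch.nn.Module, lr: float = 2e-5):
    """Hypothetical helper: step one optimizer per parameter inside backward."""
    optim_dict = {
        p: torch.optim.AdamW([p], lr=lr)
        for p in model.parameters()
        if p.requires_grad
    }

    def _optim_step_hook(param: torch.Tensor) -> None:
        # Called right after param.grad has been fully accumulated.
        optim_dict[param].step()
        optim_dict[param].zero_grad()  # free the gradient immediately

    for p in model.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(_optim_step_hook)

    return optim_dict
```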
Test plan
- Test running with optimizer_in_the_backward:
  tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 optimizer_in_bwd=True
- Test running optimizer_in_the_backward with resume_from_checkpoint:
  tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 epochs=10 optimizer_in_bwd=True resume_from_checkpoint=True checkpointer.recipe_checkpoint=/tmp/Llama-2-7b-hf/recipe_state.pt checkpointer.checkpoint_files=[hf_model_0001_1.pt,hf_model_0002_1.pt]
- Verify that running with Optimizer-in-the-backward yields the same loss, model_state_dict, and optimizer_state_dict, and that the model still matches after saving and loading (a parity-check sketch follows this list):
  pytest tests/torchtune/training/test_distributed.py -k test_optimizer_in_backward
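The parity check in the last bullet boils down to something like the following sketch (function and variable names are hypothetical; the real assertions live in test_distributed.py):
```python
import torch

def check_parity(losses_ref, losses_bwd, model_ref, model_bwd) -> None:
    # Per-step losses should match with and without optimizer-in-the-backward.
    torch.testing.assert_close(torch.tensor(losses_ref), torch.tensor(losses_bwd))
    # Model weights should also match, including after a save/load round trip.
    sd_ref, sd_bwd = model_ref.state_dict(), model_bwd.state_dict()
    assert sd_ref.keys() == sd_bwd.keys()
    for key in sd_ref:
        torch.testing.assert_close(sd_ref[key], sd_bwd[key])
```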
Memory cost analysis:
With each layer's gradients costing roughly 193 MB of memory, the original (left) case reaches its peak memory around the 31st layer, where about 30 × 193 MB of gradients have accumulated.
The right case, with optimizer-in-the-backward, frees each layer's gradient memory during the backward pass and therefore reaches a lower peak (see the estimate below).
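A back-of-the-envelope estimate of the saving, assuming the ~193 MB-per-layer gradient figure above and a 32-layer Llama-2-7B model (numbers are illustrative):
```python
grad_mb_per_layer = 193  # gradient memory per layer (from the measurement above)
# Without optimizer-in-the-backward, gradients pile up until backward finishes:
peak_without_mb = grad_mb_per_layer * 30   # ~5.8 GB of live gradients
# With optimizer-in-the-backward, each gradient is consumed and freed right away,
# so only about one layer's worth of gradient memory is live at a time:
peak_with_mb = grad_mb_per_layer * 1       # ~0.2 GB
print(f"estimated gradient-memory saving ≈ {(peak_without_mb - peak_with_mb) / 1024:.1f} GB")
```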
Training time and loss analysis:
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1737
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: You can merge normally! (2 Unrelated Failures)
As of commit f639b6d205d5de8006987717b96e323692946cb8 with merge base f639b6d205d5de8006987717b96e323692946cb8:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
- Regression Tests / regression_test (3.11, nightly) (gh) (trunk failure)
  tests/regression_tests/test_llama2_7b.py::TestLoRA7BDistributedFinetuneEval::test_finetune_and_eval
- Regression Tests / regression_test (3.11, stable) (gh) (trunk failure)
  tests/regression_tests/test_llama2_7b.py::TestLoRA7BDistributedFinetuneEval::test_finetune_and_eval
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is 1.97368% with 149 lines in your changes missing coverage. Please review.
Project coverage is 25.44%. Comparing base (7cf656b) to head (207b1b1). Report is 21 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| tests/torchtune/training/test_distributed.py | 2.88% | 101 Missing :warning: |
| recipes/full_finetune_distributed.py | 0.00% | 48 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #1737 +/- ##
===========================================
- Coverage 69.33% 25.44% -43.89%
===========================================
Files 305 305
Lines 15892 16089 +197
===========================================
- Hits 11018 4094 -6924
- Misses 4874 11995 +7121
:umbrella: View full report in Codecov by Sentry.
could we draw loss curves in Weights & Biases to showcase that numerics are the same with/without optimizer-in-the-backward?
The loss curves have been added to the comments in the third column of the table on the right-hand side.
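For reference, one way to overlay the two loss curves in Weights & Biases is to log each run into the same project, either through torchtune's metric-logger config or directly with the wandb API as in this sketch (project and run names are made up):
```python
import wandb

def log_loss_curve(losses, run_name: str) -> None:
    # One W&B run per configuration; both curves overlay in the project view.
    run = wandb.init(project="torchtune-optim-in-bwd", name=run_name)
    for step, loss in enumerate(losses):
        run.log({"loss": loss}, step=step)
    run.finish()

# log_loss_curve(baseline_losses, "baseline")
# log_loss_curve(optim_in_bwd_losses, "optimizer_in_bwd")
```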