torchtune
Fix counting of batches for gradient accumulation across epochs
Context
What is the purpose of this PR? Is it to
- [ ] add a new feature
- [x] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
As I was running DPO and SFT, I noticed two surprising behaviours when running for several epochs:
- The loss would drop at the beginning of a new epoch;
- The accuracy (for DPO) would be >100% at the beginning of a new epoch.
I noticed that re-initializing the statistics accumulators and zeroing the gradients only happens under this condition (e.g. in recipes/full_finetune_distributed.py):
```python
for idx, batch in enumerate(self._dataloader):
    ...
    # Optimizer step (if not fused in backward call)
    if (idx + 1) % self._gradient_accumulation_steps == 0:
        ...
```
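To make the off-by-one concrete, here is a small illustration (plain counters, not torchtune code) of what the per-epoch index check does when the epoch length is not a multiple of the accumulation steps:

```python
# Hypothetical illustration of the bug: 63 batches per epoch,
# gradient_accumulation_steps = 8. We count how many batches have
# accumulated when each optimizer "step" fires, resetting only on a step,
# exactly like the `(idx + 1) % accum_steps == 0` check above.
num_batches = 63
accum_steps = 8

accumulated = 0
steps = []  # number of batches contributing to each optimizer step
for epoch in range(2):
    for idx in range(num_batches):
        accumulated += 1
        if (idx + 1) % accum_steps == 0:
            steps.append(accumulated)
            accumulated = 0

# Epoch 1 fires 7 steps of 8 batches and leaves 63 - 56 = 7 batches
# accumulated; the first step of epoch 2 then fires after 8 more,
# so 15 batches contribute to it.
print(steps[7])  # first optimizer step of epoch 2
```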
Now the issue is that if your gradient accumulation parameter is set to, say, 8, but you only have 63 batches to process, the last 7 batches of the epoch accumulate statistics and gradients without being re-initialized before the new epoch starts. Accumulation then continues over the first 8 batches of the new epoch before a step is taken, so you end up accumulating 15 batches instead of 8, which skews not only the reported statistics but also the optimisation. I see at least three possibilities:
- dropping the last few batches so that `len(self._dataloader) % self._gradient_accumulation_steps == 0`;
- counting the number of batches to accumulate not by the per-epoch batch index but by the absolute number of processed batches;
- taking a step with the last few batches of the epoch, with correctly scaled gradients.
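The second option can be sketched with a running global counter in place of the per-epoch index (a minimal sketch with illustrative names, not torchtune's actual variables):

```python
# Hypothetical sketch of option 2: a global count of processed batches,
# so a step fires every `accum_steps` batches regardless of where the
# epoch boundary falls. Leftover batches at the end of an epoch simply
# carry over into the next step's accumulation window.
num_batches = 63
accum_steps = 8

global_batch_count = 0
since_last_step = 0
step_sizes = []  # batches contributing to each optimizer step
for epoch in range(2):
    for idx in range(num_batches):
        global_batch_count += 1
        since_last_step += 1
        if global_batch_count % accum_steps == 0:
            step_sizes.append(since_last_step)
            since_last_step = 0

# Every optimizer step now sees exactly `accum_steps` batches.
assert all(s == accum_steps for s in step_sizes)
```

Note this means the last partial accumulation of an epoch is carried over rather than stepped or dropped, which is exactly the carry-over behaviour discussed later in this thread.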
Changelog
This PR implements the second option. It should be done for all recipes, but I thought it was better to discuss the solution first.
Test plan
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [ ] add unit tests for any new functionality
- [ ] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [x] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
UX
- [x] I did not change any public API
- [ ] I have added an example to docs or docstrings
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2745
Codecov Report
Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.
Project coverage is 62.64%. Comparing base (`c8e670b`) to head (`72c1eea`). Report is 9 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| recipes/full_dpo_distributed.py | 0.00% | 3 Missing :warning: |
| recipes/full_finetune_distributed.py | 0.00% | 3 Missing :warning: |
Additional details and impacted files
```
@@ Coverage Diff @@
##             main    #2745    +/-  ##
=======================================
+ Coverage   60.64%   62.64%   +1.99%
=======================================
  Files         428      430       +2
  Lines       26091    26395     +304
=======================================
+ Hits        15823    16534     +711
+ Misses      10268     9861     -407
```
Thanks for your response, @joecummings. Ok sure, I can suggest one way of doing it that induces minimal change.
Hey @joecummings , any update on this?
Hi @felipemello1 ,
Thanks for your answer. It seems to me like there is some confusion about the roles of `max_steps_per_epoch`/`_steps_per_epoch`. They are not there to count the number of processed batches but the number of optimizer steps. So in your case, whether `max_steps_per_epoch` is 40 or 100, it doesn't change anything.
The 2nd solution, "counting the number of batches to accumulate not in terms of batch index but absolute number of processed batches", implied having to carry over the last few batches of an epoch. It seems to me like the solution you describe is actually the first one.