
Fixing counting number of batches for accumulation through epoch

Open wesbz opened this issue 6 months ago • 3 comments

Context

What is the purpose of this PR? Is it to

  • [ ] add a new feature
  • [x] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

As I was running DPO and SFT for several epochs, I noticed two surprising behaviours:

  1. The loss would drop at the beginning of a new epoch;
  2. The accuracy (for DPO) would exceed 100% at the beginning of a new epoch.

I noticed that the statistics accumulators are re-initialized and the gradients zeroed only when this condition holds (e.g. in recipes/full_finetune_distributed.py):
for idx, batch in enumerate(self._dataloader):
    ...
    # Optimizer step (if not fused in backward call)
    if (idx+1) % self._gradient_accumulation_steps == 0:
        ...
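The following toy simulation reproduces this index-based condition over two epochs. It counts batches only (there is no model or optimizer), and the function name is hypothetical. It assumes, as described above, that the accumulated gradients and statistics are reset only when an optimizer step is taken, while idx restarts at 0 every epoch.

```python
def batches_per_step_index_based(num_batches: int, accum_steps: int, epochs: int) -> list[int]:
    """Return how many batches were accumulated before each optimizer step.

    Mirrors the condition (idx + 1) % accum_steps == 0, where idx restarts
    at 0 every epoch but the accumulated state is only reset after a step.
    """
    steps = []
    pending = 0  # batches accumulated since the last optimizer step
    for _ in range(epochs):
        for idx in range(num_batches):  # idx restarts at 0 each epoch
            pending += 1  # forward/backward on this batch adds to the grads
            if (idx + 1) % accum_steps == 0:
                steps.append(pending)  # optimizer step, then reset
                pending = 0
    return steps


print(batches_per_step_index_based(num_batches=63, accum_steps=8, epochs=2))
```

With 63 batches per epoch and 8 accumulation steps, the eighth optimizer step (the first one of the second epoch) accumulates 15 batches instead of 8, and in this simulation the last 7 batches of the final epoch never reach a step at all. If the logged metrics are averaged by dividing by the number of accumulation steps (an assumption about how the recipes normalize them), a 15-batch window could report an accuracy of up to 15/8 = 187.5%, which would be consistent with the >100% accuracy observed above.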

The issue is that if your gradient accumulation parameter is set to, say, 8, but you only have 63 batches per epoch, the last 7 batches of the epoch accumulate statistics and gradients that are not reset before the new epoch starts. Since idx restarts at 0, the first 8 batches of the next epoch keep accumulating on top of them before a step is finally taken. You end up accumulating 15 batches instead of 8, which corrupts not only the reported statistics but also the optimization. I see at least three possibilities:

  1. dropping the last few batches so that len(self._dataloader) % self._gradient_accumulation_steps == 0;
  2. counting the batches to accumulate in terms of the absolute number of processed batches rather than the batch index within the epoch;
  3. taking an optimizer step with the last few batches of the epoch, scaling it correctly.
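Options 1 and 3 can be compared with a similar toy simulation, again counting batches only. The function names are hypothetical. For option 3, the sketch assumes that "scaling correctly" means averaging the gradient over the batches actually accumulated, so each batch in a full window is weighted 1/8 and each batch in the final 7-batch window 1/7; the real recipes may normalize differently (for example by token count).

```python
def batches_per_step_drop_last(num_batches: int, accum_steps: int, epochs: int) -> list[int]:
    """Option 1: skip the trailing batches that cannot fill a full window."""
    usable = num_batches - num_batches % accum_steps  # 63 -> 56 when accum_steps=8
    steps = []
    for _ in range(epochs):
        pending = 0
        for idx in range(usable):
            pending += 1
            if (idx + 1) % accum_steps == 0:
                steps.append(pending)
                pending = 0
    return steps


def batches_per_step_flush_last(
    num_batches: int, accum_steps: int, epochs: int
) -> list[tuple[int, float]]:
    """Option 3: also step on the trailing partial window of each epoch.

    Returns (batches accumulated, per-batch scale) for every step, assuming
    the gradient is averaged over the batches actually accumulated.
    """
    steps = []
    for _ in range(epochs):
        pending = 0
        for idx in range(num_batches):
            pending += 1
            is_last = idx == num_batches - 1
            if (idx + 1) % accum_steps == 0 or is_last:
                steps.append((pending, 1.0 / pending))
                pending = 0
    return steps


print(batches_per_step_drop_last(63, 8, 2))
print(batches_per_step_flush_last(63, 8, 1)[-2:])
```

Option 1 discards up to accum_steps - 1 batches per epoch (7 of 63 here), while option 3 uses every batch but takes one smaller step per epoch whose gradient has to be rescaled.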

Changelog

This PR implements the second option. It should be applied to all recipes, but I thought it better to discuss the solution first.

Test plan

  • [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [ ] update docstrings for any new or updated methods or classes
  • [x] run unit tests via pytest tests
  • [x] run recipe tests via pytest tests -m integration_test
  • [x] manually run any new or modified recipes with sufficient proof of correctness
  • [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

  • [x] I did not change any public API
  • [ ] I have added an example to docs or docstrings

wesbz, May 17 '25 08:05

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2745

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot], May 17 '25 08:05

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 62.64%. Comparing base (c8e670b) to head (72c1eea). Report is 9 commits behind head on main.

Files with missing lines               Patch %   Lines
recipes/full_dpo_distributed.py        0.00%     3 Missing
recipes/full_finetune_distributed.py   0.00%     3 Missing
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2745      +/-   ##
==========================================
+ Coverage   60.64%   62.64%   +1.99%     
==========================================
  Files         428      430       +2     
  Lines       26091    26395     +304     
==========================================
+ Hits        15823    16534     +711     
+ Misses      10268     9861     -407     


codecov-commenter, May 17 '25 15:05

Thanks for your response, @joecummings. OK, sure: I can suggest one way of doing it that requires minimal changes.

wesbz, May 19 '25 21:05

Hey @joecummings, any update on this?

wesbz, Jun 16 '25 17:06

Hi @felipemello1, thanks for your answer. It seems to me that there is some confusion about the roles of max_steps_per_epoch/_steps_per_epoch: they count optimizer steps, not processed batches. So in your examples, setting max_steps_per_epoch=40 or 100 doesn't change anything. The second solution ("counting the number of batches to accumulate not in terms of batch index but absolute number of processed batches") implies carrying the last few batches of an epoch over into the next one. It seems to me that the solution you describe is actually the first one.

wesbz, Jun 29 '25 17:06
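To illustrate the point about max_steps_per_epoch with the 63-batch, 8-step example: the sketch below assumes the cap is applied as a limit on optimizer steps, breaking out of the epoch once idx // gradient_accumulation_steps reaches max_steps_per_epoch. This is an assumption about how the recipes apply the cap, and epoch_summary is a hypothetical helper.

```python
from typing import Optional


def epoch_summary(
    num_batches: int, accum_steps: int, max_steps_per_epoch: Optional[int]
) -> tuple[int, int]:
    """Return (optimizer steps taken, batches still accumulated at epoch end)."""
    steps = 0
    pending = 0
    for idx in range(num_batches):
        # Assumed cap: stop once this epoch has used up its step budget.
        if max_steps_per_epoch is not None and idx // accum_steps == max_steps_per_epoch:
            break
        pending += 1
        if (idx + 1) % accum_steps == 0:
            steps += 1
            pending = 0
    return steps, pending


for cap in (None, 40, 100, 5):
    print(cap, epoch_summary(63, 8, cap))
```

At most 7 optimizer steps fit in a 63-batch epoch, so caps of 40 or 100 are never reached and the same 7 trailing batches are left accumulated either way. Only a cap below 7 (such as 5) changes the outcome, and then the epoch is cut exactly on a step boundary with nothing left over.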