
Incorrect batch progress saved in checkpoint at every_n_train_steps

Open shuaitang5 opened this issue 2 years ago • 5 comments

Bug description

When saving a checkpoint with every_n_train_steps=3, the save is performed inside the on_train_batch_end hook of the ModelCheckpoint class. During that save, the fit loop's state dict, including its batch progress, is snapshotted and saved. However, the batch progress's `completed` counter is only incremented after on_train_batch_end returns, i.e. after the checkpoint has been saved, so the saved checkpoint contains an incorrect batch_progress that looks like this:

# in checkpoint file checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
{total: {ready: 3, completed: 2, started: 3, processed: 3}}

The expected value is {total: {ready: 3, completed: 3, started: 3, processed: 3}}, which is what a checkpoint saved after validation contains.
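
For reference, the snapshotted progress can be inspected directly in the saved file; a minimal sketch, assuming a checkpoint written at a train-batch boundary (the path is illustrative):

    import torch

    # load a checkpoint saved by ModelCheckpoint at a train-batch boundary (path is illustrative)
    ckpt = torch.load("checkpoints/example-step=3.ckpt", map_location="cpu")
    progress = ckpt["loops"]["fit_loop"]["epoch_loop.batch_progress"]
    print(progress["total"])    # here `completed` lags one behind ready/started/processed
    print(progress["current"])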

As a result, when we resume from a batch-end checkpoint, the starting batch_idx is 2 while the global step seen in the training_step function of the LightningModule is 3 (they should match), and all checkpoints saved afterwards have an incorrect step value in their file names. This doesn't seem like expected behavior; am I missing something? A minimal setup to trigger it is sketched below.
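
A minimal setup that should trigger the behaviour described above (MyLightningModule and train_loader are placeholders for any standard module and dataloader; on 1.9 the imports come from pytorch_lightning instead of lightning.pytorch):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # save a checkpoint every 3 training steps and keep all of them
    checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=3, save_top_k=-1)
    trainer = Trainer(max_steps=10, callbacks=[checkpoint_cb])
    trainer.fit(MyLightningModule(), train_dataloaders=train_loader)
    # then inspect checkpoints/*.ckpt as shown above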

To work around this, I'm currently using a hack in an on_train_batch_end override like this:

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # hack: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L233-L237
        # At the time this hook is called, the `completed` counter in the batch progress has not been incremented yet.
        # If a checkpoint is saved here, it will contain an incorrect `completed` value, and resuming from it makes
        # batch_idx lag one step behind the global step seen in training_step of the LightningModule.
        trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        # revert the change so the loop's own increment afterwards doesn't double-count
        trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
        trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1
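
For context, the override above presumably sits in a ModelCheckpoint subclass; registering it is then just a matter of passing it to the Trainer (the subclass name below is made up):

    # PatchedModelCheckpoint is a hypothetical ModelCheckpoint subclass containing the override above
    checkpoint_cb = PatchedModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=3)
    trainer = Trainer(callbacks=[checkpoint_cb])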

What version are you seeing the problem on?

v1.9, master

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca @justusschock

shuaitang5 avatar Jul 11 '23 19:07 shuaitang5

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale[bot] avatar Aug 12 '23 10:08 stale[bot]

I came across this issue as well. Is there a solution to it?

ordabayevy avatar Nov 30 '23 16:11 ordabayevy

I came across this issue as well. Is there a solution to it?

You can change the checkpoint when it's being saved, for example by taking the processed value, so that the batch is rerun if the optimizer step did not complete correctly:

    checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
    checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
    checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']

heth27 avatar Apr 26 '24 22:04 heth27

Definitely still an issue and definitely still open.

docbeaker avatar Jul 01 '24 20:07 docbeaker

Same problem. Hope to see an appropriate solution.

iamlockelightning avatar Oct 12 '24 03:10 iamlockelightning