Checkpoint every_n_steps reruns epoch on restore
Bug description
The checkpoint callback runs before batch_progress.increment_completed() in the training epoch loop's advance method. As a result, in the saved checkpoint,
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] (e.g. 9)
is one smaller than, for example,
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed'] (e.g. 10) or the global step.
The same holds for checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'].
Consequently, when restoring from this checkpoint, the batch with batch_idx 9 is run again, even though the optimizer step was already performed for that batch.
This behavior is unexpected enough to at least warrant a hint in the documentation, if it is not considered a bug.
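To see the mismatch directly, the saved checkpoint can be opened with plain torch.load. The snippet below is only an illustrative sketch; it assumes a checkpoint file such as a_test_logs/epoch=0-step=10.ckpt produced by the reproduction script further down, and the expected values in the comments come from the description above.

import torch

# Illustrative path; adjust to whatever file ModelCheckpoint actually wrote.
ckpt = torch.load('a_test_logs/epoch=0-step=10.ckpt', map_location='cpu')

progress = ckpt['loops']['fit_loop']['epoch_loop.batch_progress']
print(progress['total']['completed'])   # e.g. 9
print(progress['total']['processed'])   # e.g. 10
print(ckpt['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'])  # e.g. 9, also lags behind
print(ckpt.get('global_step'))          # e.g. 10 (top-level key, if present in this version)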
What version are you seeing the problem on?
master
How to reproduce the bug
import os
import math
import time
from typing import Any

import torch
from lightning.fabric.accelerators import find_usable_cuda_devices
from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint, TQDMProgressBar
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
import lightning.pytorch as pl
from lightning.pytorch import loggers as pl_loggers


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)


class TestBatchSampler(Sampler):
    def __init__(self):
        super().__init__()

    def __len__(self) -> int:
        return 1e100
        # return len(self.train_allfiles)

    def __iter__(self):  # -> Iterator[int]:
        return self

    def __next__(self):  # -> Iterator[int]:
        return torch.tensor([1])


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return torch.randn(self.in_dim)


class TestDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(), num_workers=4, shuffle=False)
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"train_batch ended: {batch_idx}")

    def on_save_checkpoint(self, checkpoint):
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        print("creating checkpoint")

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print("validation step called")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"batch_idx: {batch_idx}")
        optimizer = self.optimizers()
        output = self.test_module_obj(batch)
        loss = output.sum()
        self.manual_backward(loss)
        optimizer.step()
        if batch_idx > 25:
            raise Exception("This is to stop the program :)")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.test_module_obj.parameters())
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()
    checkpoint_dir = 'a_test_logs/'
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=10,
        save_top_k=-1,
    )
    exception_checkpoint_callback = OnExceptionCheckpoint(
        dirpath=checkpoint_dir,
        filename="error",
    )
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, exception_checkpoint_callback],
        max_epochs=-1,
        max_steps=400000,
        val_check_interval=5,
    )
    trainer.fit(test_lit_model, test_data_loader)
    # Second run: resume from the step-10 checkpoint to observe batch_idx 9 being rerun.
    # trainer.fit(test_lit_model,
    #             datamodule=test_data_loader,
    #             ckpt_path='a_test_logs/epoch=0-step=10.ckpt')
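For completeness, the workaround hinted at by the commented-out lines in on_save_checkpoint above can also be written as a standalone callback. This is only a sketch based on the checkpoint keys shown in the description, not an officially supported API, and it may need adjustment for other Lightning versions.

from lightning.pytorch.callbacks import Callback


class AlignBatchProgress(Callback):
    # Sketch of a workaround: copy 'processed' into 'completed' so that the
    # batch whose optimizer step already ran is not replayed on resume.
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        progress = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
        progress['total']['completed'] = progress['total']['processed']
        progress['current']['completed'] = progress['current']['processed']

It would be attached via Trainer(callbacks=[checkpoint_callback, AlignBatchProgress(), ...]).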
Error messages and logs
None
Environment
Current environment
#- Lightning Component: Trainer
#- PyTorch Lightning Version: 2.2.3
More info
No response
I think this is also related to https://github.com/Lightning-AI/pytorch-lightning/issues/18595. The fact that the ModelCheckpoint is saved before all parts of the progress counters have been properly incremented seems to lead to a host of unforeseen and hard-to-debug issues.
I think it is also related to this issue: https://github.com/Lightning-AI/pytorch-lightning/issues/18060
Yes, it's the same issue; I didn't check thoroughly enough whether it already existed.