
Checkpoint every_n_train_steps reruns the last batch on restore

Open · heth27 opened this issue · 3 comments

Bug description

The checkpoint callback is invoked before batch_progress.increment_completed() in the training epoch loop's advance method. As a result, in the saved checkpoint, checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] (e.g. 9) is one smaller than checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed'] (e.g. 10) and than the global step. The same holds for checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'].

Thus, when restoring from the checkpoint, the batch with batch_idx 9 is run again, even though the optimizer step has already been performed for that batch.

This behavior is unexpected enough to at least warrant a hint in the documentation, if it is not considered a bug.
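
For reference, the off-by-one can be observed by loading one of the step checkpoints directly. This is a minimal sketch; the checkpoint path matches the commented-out resume call in the repro below and may differ on your machine.

import torch

# Load a checkpoint written by ModelCheckpoint(every_n_train_steps=10)
ckpt = torch.load('a_test_logs/epoch=0-step=10.ckpt', map_location='cpu')

total = ckpt['loops']['fit_loop']['epoch_loop.batch_progress']['total']
print(total['processed'])   # e.g. 10
print(total['completed'])   # e.g. 9, one behind 'processed'
print(ckpt['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'])  # also lags by one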

What version are you seeing the problem on?

master

How to reproduce the bug

from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint
from lightning.pytorch.utilities.types import STEP_OUTPUT


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)


class TestBatchSampler(Sampler):
    """Effectively infinite batch sampler that always yields the same single-index batch."""

    def __init__(self):
        super().__init__()

    def __len__(self) -> int:
        # Must return an int; the original 1e100 (a float) would raise a TypeError if len() were called.
        return 10**100

    def __iter__(self):
        return self

    def __next__(self):
        return torch.tensor([1])


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return torch.randn(self.in_dim)


class TestDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(), num_workers=4, shuffle=False)
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"train_batch ended:{batch_idx}")

    def on_save_checkpoint(self, checkpoint):
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        print(f"creating checkpoint")

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print(f"validation step called")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"batch_idx: {batch_idx}")
        optimizer = self.optimizers()

        output = self.test_module_obj(batch)

        loss = output.sum()

        self.manual_backward(loss)

        optimizer.step()

        if batch_idx > 25:
            # Stop training deliberately so that OnExceptionCheckpoint writes a checkpoint.
            raise Exception("This is to stop the program :)")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()

    checkpoint_dir = 'a_test_logs/'

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=10,
        save_top_k=-1,
    )
    exception_checkpoint_callback = OnExceptionCheckpoint(
        dirpath=checkpoint_dir,
        filename="error"
    )
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, exception_checkpoint_callback],
        max_epochs=-1,
        max_steps=400000,
        val_check_interval=5,
    )
    trainer.fit(test_lit_model, test_data_loader)

    # To reproduce the rerun, comment out the fit() call above and resume from a saved step checkpoint:
    # trainer.fit(test_lit_model,
    #             datamodule=test_data_loader,
    #             ckpt_path='a_test_logs/epoch=0-step=10.ckpt')
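
For completeness, the commented-out lines in on_save_checkpoint above sketch a possible workaround: copying the 'processed' counters into 'completed' before the checkpoint is written, so the already-stepped batch is not rerun on restore. Uncommented, the hook on TestLitModel would look like the following; this is only a sketch of the workaround from the repro, not an officially supported fix.

    def on_save_checkpoint(self, checkpoint):
        # Workaround sketch: align 'completed' with 'processed' so restoring
        # does not rerun the batch whose optimizer step already happened.
        batch_progress = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
        batch_progress['total']['completed'] = batch_progress['total']['processed']
        batch_progress['current']['completed'] = batch_progress['current']['processed']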

Error messages and logs

None

Environment

- Lightning Component: Trainer
- PyTorch Lightning Version: 2.2.3

More info

No response

heth27 · Apr 25 '24

I think this is also related to https://github.com/Lightning-AI/pytorch-lightning/issues/18595. The fact that the ModelCheckpoint is saved before all parts of the counters have been properly incremented seems to lead to a host of unforeseen and hard-to-debug issues.

johnzielke · Apr 26 '24

I think it is also related to this issue https://github.com/Lightning-AI/pytorch-lightning/issues/18060

ordabayevy · Apr 26 '24

> I think it is also related to this issue #18060

Yes, it's the same issue; I didn't check thoroughly enough whether it already existed.

heth27 · Apr 26 '24