
Potential off-by-one error when resuming training from a mid-epoch checkpoint

Open · ivnle opened this issue 1 year ago · 0 comments

Bug description

Here is a simple log of the global_step and batch_idx values printed during the fit loop from on_train_batch_start, on_train_batch_end and on_validation_batch_start:

[TRAIN STEP START] trainer.global_step=0, batch_idx=0
[TRAIN STEP END] trainer.global_step=1, batch_idx=0
[TRAIN STEP START] trainer.global_step=1, batch_idx=1
[TRAIN STEP END] trainer.global_step=2, batch_idx=1
[TRAIN STEP START] trainer.global_step=2, batch_idx=2
[TRAIN STEP END] trainer.global_step=3, batch_idx=2
[TRAIN STEP START] trainer.global_step=3, batch_idx=3
[TRAIN STEP END] trainer.global_step=4, batch_idx=3
[TRAIN STEP START] trainer.global_step=4, batch_idx=4
[TRAIN STEP END] trainer.global_step=5, batch_idx=4
[VAL STEP START] trainer.global_step=5, batch_idx=0
[VAL STEP START] trainer.global_step=5, batch_idx=1
[VAL STEP START] trainer.global_step=5, batch_idx=2
[VAL STEP START] trainer.global_step=5, batch_idx=3
[VAL STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP START] trainer.global_step=5, batch_idx=5
[TRAIN STEP END] trainer.global_step=6, batch_idx=5
[TRAIN STEP START] trainer.global_step=6, batch_idx=6
[TRAIN STEP END] trainer.global_step=7, batch_idx=6
[TRAIN STEP START] trainer.global_step=7, batch_idx=7
[TRAIN STEP END] trainer.global_step=8, batch_idx=7
[TRAIN STEP START] trainer.global_step=8, batch_idx=8
[TRAIN STEP END] trainer.global_step=9, batch_idx=8
[TRAIN STEP START] trainer.global_step=9, batch_idx=9
[TRAIN STEP END] trainer.global_step=10, batch_idx=9
[VAL STEP START] trainer.global_step=10, batch_idx=0
[VAL STEP START] trainer.global_step=10, batch_idx=1
[VAL STEP START] trainer.global_step=10, batch_idx=2
[VAL STEP START] trainer.global_step=10, batch_idx=3
[VAL STEP START] trainer.global_step=10, batch_idx=4
`Trainer.fit` stopped: `max_steps=10` reached.

Notice that global_step and batch_idx are equal at batch start, and global_step is 1 greater than batch_idx at batch end. Now, if I save a mid-epoch checkpoint after 5 training steps and resume training from it, I see the following:

[TRAIN STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP END] trainer.global_step=6, batch_idx=4
[VAL STEP START] trainer.global_step=6, batch_idx=0
[VAL STEP START] trainer.global_step=6, batch_idx=1
[VAL STEP START] trainer.global_step=6, batch_idx=2
[VAL STEP START] trainer.global_step=6, batch_idx=3
[VAL STEP START] trainer.global_step=6, batch_idx=4
[TRAIN STEP START] trainer.global_step=6, batch_idx=5
[TRAIN STEP END] trainer.global_step=7, batch_idx=5
[TRAIN STEP START] trainer.global_step=7, batch_idx=6
[TRAIN STEP END] trainer.global_step=8, batch_idx=6
[TRAIN STEP START] trainer.global_step=8, batch_idx=7
[TRAIN STEP END] trainer.global_step=9, batch_idx=7
[TRAIN STEP START] trainer.global_step=9, batch_idx=8
[TRAIN STEP END] trainer.global_step=10, batch_idx=8
`Trainer.fit` stopped: `max_steps=10` reached.

Now the two values differ by 1 at batch start and by 2 at batch end, and batch_idx=4, which was already trained before the checkpoint was saved, is trained a second time. This matters because it changes when validation and checkpointing run. Both runs use Trainer(val_check_interval=5, ...) and ModelCheckpoint(every_n_train_steps=5, ...). In the original run, validation happens after 5 and 10 training steps, as expected. In the resumed run, validation happens only once, after 6 training steps.
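These offsets can also be computed mechanically from the printed logs, which helps when comparing longer runs. This is a minimal sketch that assumes the exact line format printed by MyCallback in the reproduction script. `distinct_offsets` is an illustrative helper, not part of Lightning, and the sample strings are a few lines copied from the two logs above:

```python
import re

# Matches the training lines printed by MyCallback, e.g.
# "[TRAIN STEP START] trainer.global_step=5, batch_idx=4"
# (validation lines start with "[VAL STEP" and are ignored).
TRAIN_LINE = re.compile(
    r"\[TRAIN STEP (START|END)\] trainer\.global_step=(\d+), batch_idx=(\d+)"
)


def distinct_offsets(log_text):
    """Map "START"/"END" to the sorted distinct values of global_step - batch_idx."""
    offsets = {}
    for match in TRAIN_LINE.finditer(log_text):
        phase = match.group(1)
        offset = int(match.group(2)) - int(match.group(3))
        offsets.setdefault(phase, set()).add(offset)
    return {phase: sorted(values) for phase, values in offsets.items()}


# A few lines copied from the uninterrupted run and from the resumed run.
ORIGINAL_LOG = """
[TRAIN STEP START] trainer.global_step=4, batch_idx=4
[TRAIN STEP END] trainer.global_step=5, batch_idx=4
[VAL STEP START] trainer.global_step=5, batch_idx=0
[TRAIN STEP START] trainer.global_step=5, batch_idx=5
[TRAIN STEP END] trainer.global_step=6, batch_idx=5
"""
RESUMED_LOG = """
[TRAIN STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP END] trainer.global_step=6, batch_idx=4
[VAL STEP START] trainer.global_step=6, batch_idx=0
[TRAIN STEP START] trainer.global_step=6, batch_idx=5
[TRAIN STEP END] trainer.global_step=7, batch_idx=5
"""

print(distinct_offsets(ORIGINAL_LOG))  # {'START': [0], 'END': [1]}
print(distinct_offsets(RESUMED_LOG))   # {'START': [1], 'END': [2]}
```

A consistent offset of 0 at START in a resumed run would indicate the counters line up again.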

My initial guess is that this happens because self.batch_progress.increment_completed() in src/lightning/pytorch/loops/training_epoch_loop.py is called after

call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
trainer._logger_connector.on_batch_end()

so a checkpoint saved from on_train_batch_end records only global_step - 1 completed training batches.
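The hypothesis above can be checked against the logs with a toy model of the two counters. This is a sketch under stated assumptions, not Lightning's code: `ToyLoop`, `completed_batches`, `increment_before_hooks` and `resume` are hypothetical names. It assumes the checkpoint snapshots both counters inside on_train_batch_end whenever global_step % every_n_train_steps == 0, that a resumed run restarts batch_idx at the restored completed count, and that validation fires when (batch_idx + 1) % val_check_interval == 0, which is consistent with both logs.

```python
from dataclasses import dataclass, field


@dataclass
class ToyLoop:
    """Hypothetical model of the fit loop's counters; NOT Lightning's real classes."""

    val_check_interval: int = 5
    every_n_train_steps: int = 5
    max_steps: int = 10
    # False mimics the suspected ordering: the "completed" counter is bumped
    # only after the on_train_batch_end hooks (where the checkpoint is saved).
    increment_before_hooks: bool = False
    global_step: int = 0
    completed_batches: int = 0  # stands in for batch_progress's completed count
    trained_batches: list = field(default_factory=list)
    val_steps: list = field(default_factory=list)
    checkpoints: dict = field(default_factory=dict)

    def run(self):
        # Assumption: on resume, iteration restarts at the restored completed count.
        batch_idx = self.completed_batches
        while self.global_step < self.max_steps:
            self.trained_batches.append(batch_idx)
            self.global_step += 1  # optimizer step
            if self.increment_before_hooks:
                self.completed_batches += 1
            # "on_train_batch_end": a ModelCheckpoint-like snapshot of both counters
            if self.global_step % self.every_n_train_steps == 0:
                self.checkpoints[self.global_step] = (self.global_step, self.completed_batches)
            if not self.increment_before_hooks:
                self.completed_batches += 1
            # Assumed validation trigger, based on the batch index within the epoch
            if (batch_idx + 1) % self.val_check_interval == 0:
                self.val_steps.append(self.global_step)
            batch_idx += 1
        return self


def resume(snapshot, **kwargs):
    """Start a fresh ToyLoop from a (global_step, completed_batches) snapshot."""
    global_step, completed = snapshot
    return ToyLoop(global_step=global_step, completed_batches=completed, **kwargs).run()


first = ToyLoop().run()
print(first.val_steps)         # [5, 10]
print(first.checkpoints[5])    # (5, 4): completed count lags global_step by one
second = resume(first.checkpoints[5])
print(second.trained_batches)  # [4, 5, 6, 7, 8]: batch 4 is trained again
print(second.val_steps)        # [6]
```

With the suspected ordering, the toy model reproduces both symptoms from the resumed log. Constructing it with increment_before_hooks=True makes the step-5 snapshot (5, 5), and the resumed run then trains batches 5 to 9 and validates at step 10. That is consistent with the guess above, but it says nothing about whether reordering the calls is safe in the real training loop.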

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from torch.utils.data import DataLoader, random_split


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.model = torch.compile(self.model)

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("test_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    # def on_train_batch_end(self, outputs, batch, batch_idx):
    #     print(f"{self.trainer.global_step=}, {batch_idx=}")


class MyCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[TRAIN STEP START] {trainer.global_step=}, {batch_idx=}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f"[TRAIN STEP END] {trainer.global_step=}, {batch_idx=}")

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[VAL STEP START] {trainer.global_step=}, {batch_idx=}")


class _ModelCheckpoint(ModelCheckpoint):
    """Modified version of ModelCheckpoint that saves the model after fit completes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def main():
    L.seed_everything(42)

    # Data
    dataset = WikiText2()

    # Split data into train, val, test
    n = len(dataset)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [n - 200, 100, 100]
    )
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(f"{len(train_dataset)=}")

    # Model
    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5,
        save_top_k=-1,
        enable_version_counter=False,
        verbose=True,
    )
    my_callback = MyCallback()

    # Trainer
    trainer = L.Trainer(
        max_steps=10,
        val_check_interval=5,
        limit_val_batches=5,
        callbacks=[my_callback, checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    # trainer.test(model, test_dataloader)

    # Resume training from checkpoint
    # Use the callback's resolved directory so this also works if lightning_logs/version_0 already exists
    ckpt_path = f"{checkpoint_callback.dirpath}/epoch=0-step=5.ckpt"
    print(f"Resuming training from checkpoint {ckpt_path}")
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        ckpt_path=ckpt_path,
    )


if __name__ == "__main__":
    main()

Error messages and logs


Environment

Current environment

More info

No response

cc @carmocca @justusschock

ivnle, Jan 29 '24 21:01