pytorch-lightning
Potential off-by-1 error when resuming training from a mid-epoch checkpoint
Bug description
During the fit loop, here's a simple log of the global_step and batch_idx values during on_train_batch_start and on_train_batch_end:
[TRAIN STEP START] trainer.global_step=0, batch_idx=0
[TRAIN STEP END] trainer.global_step=1, batch_idx=0
[TRAIN STEP START] trainer.global_step=1, batch_idx=1
[TRAIN STEP END] trainer.global_step=2, batch_idx=1
[TRAIN STEP START] trainer.global_step=2, batch_idx=2
[TRAIN STEP END] trainer.global_step=3, batch_idx=2
[TRAIN STEP START] trainer.global_step=3, batch_idx=3
[TRAIN STEP END] trainer.global_step=4, batch_idx=3
[TRAIN STEP START] trainer.global_step=4, batch_idx=4
[TRAIN STEP END] trainer.global_step=5, batch_idx=4
[VAL STEP START] trainer.global_step=5, batch_idx=0
[VAL STEP START] trainer.global_step=5, batch_idx=1
[VAL STEP START] trainer.global_step=5, batch_idx=2
[VAL STEP START] trainer.global_step=5, batch_idx=3
[VAL STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP START] trainer.global_step=5, batch_idx=5
[TRAIN STEP END] trainer.global_step=6, batch_idx=5
[TRAIN STEP START] trainer.global_step=6, batch_idx=6
[TRAIN STEP END] trainer.global_step=7, batch_idx=6
[TRAIN STEP START] trainer.global_step=7, batch_idx=7
[TRAIN STEP END] trainer.global_step=8, batch_idx=7
[TRAIN STEP START] trainer.global_step=8, batch_idx=8
[TRAIN STEP END] trainer.global_step=9, batch_idx=8
[TRAIN STEP START] trainer.global_step=9, batch_idx=9
[TRAIN STEP END] trainer.global_step=10, batch_idx=9
[VAL STEP START] trainer.global_step=10, batch_idx=0
[VAL STEP START] trainer.global_step=10, batch_idx=1
[VAL STEP START] trainer.global_step=10, batch_idx=2
[VAL STEP START] trainer.global_step=10, batch_idx=3
[VAL STEP START] trainer.global_step=10, batch_idx=4
`Trainer.fit` stopped: `max_steps=10` reached.
Notice that global_step and batch_idx are equal during on_train_batch_start, and global_step is 1 greater than batch_idx during on_train_batch_end. Now, if I save a mid-epoch checkpoint after 5 training steps and resume training, I see the following:
[TRAIN STEP START] trainer.global_step=5, batch_idx=4
[TRAIN STEP END] trainer.global_step=6, batch_idx=4
[VAL STEP START] trainer.global_step=6, batch_idx=0
[VAL STEP START] trainer.global_step=6, batch_idx=1
[VAL STEP START] trainer.global_step=6, batch_idx=2
[VAL STEP START] trainer.global_step=6, batch_idx=3
[VAL STEP START] trainer.global_step=6, batch_idx=4
[TRAIN STEP START] trainer.global_step=6, batch_idx=5
[TRAIN STEP END] trainer.global_step=7, batch_idx=5
[TRAIN STEP START] trainer.global_step=7, batch_idx=6
[TRAIN STEP END] trainer.global_step=8, batch_idx=6
[TRAIN STEP START] trainer.global_step=8, batch_idx=7
[TRAIN STEP END] trainer.global_step=9, batch_idx=7
[TRAIN STEP START] trainer.global_step=9, batch_idx=8
[TRAIN STEP END] trainer.global_step=10, batch_idx=8
`Trainer.fit` stopped: `max_steps=10` reached.
Now the two values are off by 1 during batch start and off by 2 during batch end. This seems to be an issue because it changes when validation and checkpointing run. In both runs, I have Trainer(val_check_interval=5, ...) and ModelCheckpoint(every_n_train_steps=5, ...). In the original run, validation happens after 5 and 10 training steps, as expected. In the resumed run, validation only happens once, after 6 training steps.
My initial guess is that this is happening because self.batch_progress.increment_completed() in src/lightning/pytorch/loops/training_epoch_loop.py is called after

call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
trainer._logger_connector.on_batch_end()

so the checkpoint thinks we've only completed global_step - 1 training steps.
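One way to sanity-check this guess is to look at what the mid-epoch checkpoint actually recorded. A minimal sketch, assuming the Lightning 2.x checkpoint layout (top-level "global_step" and "loops" keys) and the checkpoint path produced by the repro below:

import torch

# Checkpoint written after 5 training steps by the repro script below
# (path assumed from the default lightning_logs layout; adjust to your run).
ckpt_path = "lightning_logs/version_0/checkpoints/epoch=0-step=5.ckpt"
ckpt = torch.load(ckpt_path, map_location="cpu")

# Step counter stored at the top level of the checkpoint.
print("global_step:", ckpt["global_step"])

# Restored loop-progress counters; if the guess above is right, the "completed"
# batch count in here should lag global_step by one.
print("fit_loop state:", ckpt["loops"]["fit_loop"])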
What version are you seeing the problem on?
v2.1
How to reproduce the bug
import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from torch.utils.data import DataLoader, random_split


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.model = torch.compile(self.model)

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("test_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    # def on_train_batch_end(self, outputs, batch, batch_idx):
    #     print(f"{self.trainer.global_step=}, {batch_idx=}")


class MyCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[TRAIN STEP START] {trainer.global_step=}, {batch_idx=}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print(f"[TRAIN STEP END] {trainer.global_step=}, {batch_idx=}")

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"[VAL STEP START] {trainer.global_step=}, {batch_idx=}")


class _ModelCheckpoint(ModelCheckpoint):
    """Modified version of ModelCheckpoint that saves the model after fit completes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def main():
    L.seed_everything(42)

    # Data
    dataset = WikiText2()

    # Split data into train, val, test
    n = len(dataset)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [n - 200, 100, 100]
    )
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print(f"{len(train_dataset)=}")

    # Model
    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5,
        save_top_k=-1,
        enable_version_counter=False,
        verbose=True,
    )
    my_callback = MyCallback()

    # Trainer
    trainer = L.Trainer(
        max_steps=10,
        val_check_interval=5,
        limit_val_batches=5,
        callbacks=[my_callback, checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    # trainer.test(model, test_dataloader)

    # Resume training from checkpoint
    ckpt_path = "lightning_logs/version_0/checkpoints/epoch=0-step=5.ckpt"
    print(f"Resuming training from checkpoint {ckpt_path}")
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        ckpt_path=ckpt_path,
    )


if __name__ == "__main__":
    main()
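If it helps, here is a small add-on callback (not part of the original repro; the trainer.fit_loop.epoch_loop.batch_progress attribute path and its counter names are my assumption based on Lightning 2.x) that also prints the internal batch-progress counters next to global_step, to make the mismatch visible during both the fresh and the resumed run:

from lightning.pytorch.callbacks import Callback


class ProgressDebugCallback(Callback):
    """Prints the epoch loop's batch-progress counters alongside global_step."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # batch_progress.current tracks ready/started/processed/completed counts
        # for the current epoch (attribute names assumed from Lightning 2.x).
        current = trainer.fit_loop.epoch_loop.batch_progress.current
        print(
            f"[PROGRESS] global_step={trainer.global_step}, batch_idx={batch_idx}, "
            f"processed={current.processed}, completed={current.completed}"
        )

Adding it to callbacks=[...] in the Trainer above should show "completed" trailing global_step by one at the moment on_train_batch_end (and hence ModelCheckpoint) runs.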
Error messages and logs
Environment
More info
No response
cc @carmocca @justusschock