Calling iter() twice messes up dataloaders with queues
Bug description
This bug has reappeared https://github.com/Lightning-AI/pytorch-lightning/issues/18414
We now call iter() twice in different places:
- https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py#L263
- https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L171C1-L172C1
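To see why two iter() calls lose data, here is a pure-Python toy with no Lightning or PyTorch involved. PrefetchingLoader and resume_epoch are hypothetical names. PrefetchingLoader stands in for a DataLoader whose worker starts pulling from the shared queue as soon as an iterator is created. The value prefetch=3 is an assumption, chosen only to mirror the logs in this issue. Items buffered by the first, discarded iterator are gone from the queue, so the second iterator only sees what is left.

```python
from collections import deque
from queue import Queue
from typing import Deque, List, Tuple


class _PrefetchingIterator:
    def __init__(self, queue: "Queue[int]", prefetch: int) -> None:
        self.queue = queue
        self.buffer: Deque[int] = deque()
        # Eagerly pull up to `prefetch` items, like a worker that starts
        # prefetching as soon as the iterator is created.
        for _ in range(prefetch):
            if self.queue.empty():
                break
            self.buffer.append(self.queue.get_nowait())

    def __iter__(self) -> "_PrefetchingIterator":
        return self

    def __next__(self) -> int:
        # Stay ahead by pulling one more item from the queue, if any is left.
        if not self.queue.empty():
            self.buffer.append(self.queue.get_nowait())
        if not self.buffer:
            raise StopIteration
        return self.buffer.popleft()


class PrefetchingLoader:
    """Hypothetical stand-in for a DataLoader reading from a shared queue."""

    def __init__(self, queue: "Queue[int]", prefetch: int = 3) -> None:
        self.queue = queue
        self.prefetch = prefetch

    def __iter__(self) -> _PrefetchingIterator:
        return _PrefetchingIterator(self.queue, self.prefetch)


def resume_epoch() -> Tuple[List[int], List[int]]:
    q: "Queue[int]" = Queue()
    for index in range(5, 10):  # the 5 items left after the first fit
        q.put(index)
    loader = PrefetchingLoader(q)
    first = iter(loader)   # like the fit loop's iter()
    second = iter(loader)  # like the training epoch loop's iter()
    lost = list(first.buffer)  # stranded in an iterator nobody reads
    seen = list(second)
    return lost, seen
```

resume_epoch() returns ([5, 6, 7], [8, 9]): the same split as "got 5" to "got 7" followed by "got 8" and "got 9" in the logs.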
What version are you seeing the problem on?
v2.1
How to reproduce the bug
```python
import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))

    max_epoch = 1
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint("model.ckpt")

    # q now holds the remaining 5 elements;
    # resuming training hits the double iter() issue
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch + 1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader, ckpt_path="model.ckpt")
```
Error messages and logs
Relevant logs:
```
# first epoch: all good
getting 0
got 0
getting 1
got 1
getting 2
got 2
getting 3
got 3
getting 4
got 4
# second epoch: we start reading from the queue twice!
# from the fit loop's iter()
getting 0
got 5
getting 1
got 6
getting 2
got 7
# from the training epoch loop's iter()
getting 0
got 8
getting 1
got 9
getting 2
```
Environment
lightning==2.1.4
cc @justusschock @awaelchli @carmocca
This condition is meant to prevent iter() from being called a second time, because in this case restarting should be True:
https://github.com/Lightning-AI/pytorch-lightning/blob/47c8f4cba089a78fa3fe31dcac6a43416bc13820/src/lightning/pytorch/loops/training_epoch_loop.py#L169-L171
But it isn't. The problem is that the fit loop sets restarting=False even though we are resuming, due to the logic here:
https://github.com/Lightning-AI/pytorch-lightning/blob/47c8f4cba089a78fa3fe31dcac6a43416bc13820/src/lightning/pytorch/loops/fit_loop.py#L123-L128
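A simplified paraphrase of the two linked snippets, as we read them. EpochProgress, fit_loop_restarting and iter_calls_for_epoch are hypothetical names, and we assume the restarting value computed by the fit loop is the one the epoch loop's guard ends up seeing. A checkpoint written after a fully finished epoch has equal ready/started/processed counters, so restarting comes out False and the guard lets the second iter() through.

```python
from dataclasses import dataclass


@dataclass
class EpochProgress:
    """Hypothetical mirror of the fit loop's epoch progress counters."""

    ready: int
    started: int
    processed: int


def fit_loop_restarting(requested: bool, progress: EpochProgress, iteration_based: bool) -> bool:
    # Paraphrase of the linked setter: if the last epoch fully finished,
    # we are "not actually restarting" unless training is iteration-based.
    epoch_unfinished = any(v != progress.processed for v in (progress.ready, progress.started))
    return (requested and epoch_unfinished) or iteration_based


def iter_calls_for_epoch(current_epoch: int, restarting: bool) -> int:
    calls = 1  # the fit loop always creates the iterator once
    # Paraphrase of the linked epoch-loop guard: skip the second iter() when restarting.
    if current_epoch > 0 and not restarting:
        calls += 1
    return calls


# Resuming from a checkpoint saved after epoch 0 completed (assumed current_epoch == 1):
done = EpochProgress(ready=1, started=1, processed=1)
restarting = fit_loop_restarting(True, done, iteration_based=False)
print(restarting, iter_calls_for_epoch(1, restarting))  # False 2
```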
@carmocca, this is tricky to solve. The logic probably needs to be lifted up into the fit loop before epoch_loop.run(), with a different condition that does not rely on restarting.
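One possible shape for that idea, sketched with hypothetical names (ToyFitLoop, fresh, run_epoch) rather than Lightning's real classes: the fit loop remembers whether the iterator it created during setup is still unused, and only calls iter() again once it has been consumed. The condition depends on the iterator's state, not on restarting, so a resumed run that starts at current_epoch > 0 still calls iter() once for that epoch. It ignores mid-epoch resumption details that the real loop would still have to handle.

```python
from typing import Iterable, Iterator, List, Optional


class ToyFitLoop:
    """Hypothetical fit loop that owns the decision to call iter()."""

    def __init__(self, loader: Iterable[int]) -> None:
        self.loader = loader
        self.iterator: Optional[Iterator[int]] = None
        self.fresh = False  # True while the current iterator is still unused
        self.iter_calls = 0

    def _new_iterator(self) -> None:
        self.iterator = iter(self.loader)
        self.iter_calls += 1
        self.fresh = True

    def setup_data(self) -> None:
        # Mirrors the eager iter() the fit loop already does during setup.
        self._new_iterator()

    def run_epoch(self) -> List[int]:
        # Decided here, before the epoch loop runs: reuse an unused iterator,
        # otherwise start a new pass over the data.
        if not self.fresh:
            self._new_iterator()
        self.fresh = False
        assert self.iterator is not None
        return list(self.iterator)  # stands in for epoch_loop.run(...)
```

For example, setup_data() followed by two run_epoch() calls over range(3) yields [0, 1, 2] twice with iter_calls == 2: one call per epoch, never two for the same epoch.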
I didn't look too deeply. Couldn't we also check restarting for the FitLoop's iter() call? We have a lot of tests around this, so if a solution passes them we should be good.
The problem with the restarting property is that self._iteration_based_training() is False here, so restarting also ends up False.
Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?
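A regression test could count how often the training loader's __iter__ is entered per epoch. CountingIterable and assert_one_iter_per_epoch are hypothetical helpers that show only the counting part. How to wire them into a Lightning test is an assumption we have not verified. For instance, one could wrap or subclass the DataLoader rather than the dataset, because with num_workers > 0 the dataset's __iter__ runs in the worker process, where a counter in the main process would not see it.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class CountingIterable(Generic[T]):
    """Hypothetical test helper: counts how often iter() is called on `inner`."""

    def __init__(self, inner: Iterable[T]) -> None:
        self.inner = inner
        self.iter_calls = 0

    def __iter__(self) -> Iterator[T]:
        self.iter_calls += 1
        return iter(self.inner)


def assert_one_iter_per_epoch(loader: CountingIterable, epochs_run: int) -> None:
    # The bug in this issue would show up as an extra call on the resumed epoch.
    assert loader.iter_calls == epochs_run, (
        f"iter() called {loader.iter_calls} times for {epochs_run} epoch(s)"
    )
```

Under the behavior reported here, resuming and running one more epoch would presumably leave iter_calls at 2, so assert_one_iter_per_epoch(loader, 1) would fail.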