Resume from mid-epoch steps
Description & Motivation
LLMs are trained on ever-growing corpora, so being able to resume only at epoch boundaries is not enough: a model may be trained for just a few epochs, and a single epoch can take days. Currently, Lightning prints a warning when trying to resume from mid-epoch steps and asks for a resumable dataloader.
However, I can't find any example of resuming from mid-epoch steps in the docs or blog posts (maybe I just missed it). It also seems strange to implement a dataloader with state_dict/load_state_dict methods: a dataloader does not hold state by design; it is the iterator derived from the dataloader that is resumable and should hold the necessary state. Besides, we may not need state_dict/load_state_dict to save/restore dataloaders at all, since the epoch/step indices already carry enough information to restore the training batch position.
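Not an official Lightning example, just my guess at what the warning seems to ask for: a DataLoader subclass exposing state_dict/load_state_dict (class name and state layout are made up). Note that skipping here still iterates over already-consumed batches, which is exactly the preprocessing overhead mentioned above:

from torch.utils.data import DataLoader

class ResumableDataLoader(DataLoader):
    """Hypothetical sketch of a 'resumable' dataloader, not a Lightning API."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batches_served = 0

    def __iter__(self):
        skip = self._batches_served
        for i, batch in enumerate(super().__iter__()):
            if i < skip:
                continue  # fast-forward past batches consumed before the checkpoint
            self._batches_served = i + 1
            yield batch
        self._batches_served = 0  # a fresh epoch starts from the beginning again

    def state_dict(self):
        return {"batches_served": self._batches_served}

    def load_state_dict(self, state):
        self._batches_served = state["batches_served"]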
Below I propose a possible hack that works around this issue, taking inspiration from the Hugging Face training script.
Pitch
No response
Alternatives
Here is the ugly hack (via hooks in the LightningModule) that I currently use to resume from a specific batch:
import lightning as L
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
# NOTE: the private-utility import paths below may differ between Lightning versions
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.utilities.seed import isolate_rng


class SkipBatchSampler(BatchSampler):
    r"""Batch sampler that skips the first ``skip_batches`` batches.

    Modified from Hugging Face accelerate/data_loader.py.
    """

    def __init__(self, batch_sampler: BatchSampler, skip_batches: int = 0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for i, batch in enumerate(self.batch_sampler):
            if i >= self.skip_batches:
                yield batch

    def __len__(self):
        # Intentionally NOT ``len(self.batch_sampler) - self.skip_batches``:
        # loops/training_epoch_loop.py on_run_start() restores the fetched count,
        # so the full length has to be reported here. Ugly hack.
        return len(self.batch_sampler)


# Defaults used when an attribute cannot be read off the original DataLoader.
_PYTORCH_DATALOADER_KWARGS_SUBSTITUTE = {
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}


def resume_dataloader(dataloader: DataLoader, steps_in_epoch: int) -> DataLoader:
    r"""Rebuild ``dataloader`` so it skips the first ``steps_in_epoch`` batches.

    We don't want to iterate over the dataloader directly (which would incur the
    data-preprocessing overhead), so we skip at the batch-sampler level instead.
    """
    # TODO: iterable datasets, DataLoaderDispatcher and DataLoaderShard are not supported yet
    assert not isinstance(dataloader.dataset, IterableDataset)
    new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, steps_in_epoch)
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS_SUBSTITUTE[k])
        for k in _PYTORCH_DATALOADER_KWARGS_SUBSTITUTE
    }
    return DataLoader(dataloader.dataset, batch_sampler=new_batch_sampler, **kwargs)
class LightningModel(L.LightningModule):
    # hacks: swap in a batch-skipping dataloader for the first epoch after restarting
    def on_train_start(self):
        self.restarted_run = False

    def on_train_epoch_start(self):
        # modify the train dataloader(s)
        if self.trainer.fit_loop.restarting:
            self.restarted_run = True
            self.trainer.fit_loop.backup_dataloaders = self.trainer.fit_loop._combined_loader.flattened
            self.trainer.fit_loop._combined_loader.flattened = [
                resume_dataloader(dl, self.trainer.fit_loop.epoch_loop.batch_progress.current.completed)
                for dl in self.trainer.fit_loop._combined_loader.flattened
            ]
            # iter() has to be called to rebuild data_fetcher.iterator
            # (which is originally set in setup_data)
            self.trainer.fit_loop._data_fetcher.setup(self.trainer.fit_loop._combined_loader)
            with isolate_rng():
                iter(self.trainer.fit_loop._data_fetcher)
        else:
            if self.restarted_run:
                # restore the original dataloaders once the resumed epoch is done
                self.trainer.fit_loop._combined_loader.flattened = self.trainer.fit_loop.backup_dataloaders
                # set the sampler epoch again, because the epoch right after the
                # restarted one would otherwise be off
                for dl in self.trainer.fit_loop._combined_loader.flattened:
                    _set_sampler_epoch(dl, self.trainer.current_epoch)
                self.trainer.fit_loop._data_fetcher.setup(self.trainer.fit_loop._combined_loader)
                # no need to rebuild the iterator here; epoch_loop.on_run_start already does it
                # iter(self.trainer.fit_loop._data_fetcher)
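For completeness, a rough usage sketch (train_loader, max_epochs and the checkpoint path are placeholders): resuming via ckpt_path is what sets fit_loop.restarting, which the on_train_epoch_start hook above reacts to.

model = LightningModel()
trainer = L.Trainer(max_epochs=10)
# train_loader is any map-style DataLoader; ckpt_path points to a mid-epoch checkpoint
trainer.fit(model, train_dataloaders=train_loader, ckpt_path="path/to/mid_epoch.ckpt")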
Additional context
No response
cc @borda
It seems like StatefulDataLoader from torchdata might help here. However, if I replace my old dataloader with StatefulDataLoader, I cannot find a corresponding entry in the saved checkpoint, and the warning doesn't appear either.
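A minimal sketch of the setup being described, in case it helps reproduce (the datamodule and dataset names are placeholders; assumes torchdata's StatefulDataLoader used as a drop-in DataLoader replacement):

import lightning as L
from torchdata.stateful_dataloader import StatefulDataLoader

class MyDataModule(L.LightningDataModule):
    def train_dataloader(self):
        # StatefulDataLoader implements state_dict()/load_state_dict(),
        # yet no corresponding entry shows up in the saved checkpoint
        return StatefulDataLoader(self.train_dataset, batch_size=32, shuffle=True, num_workers=4)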
I am experiencing the same problem when resuming training with a huge amount of data per epoch. I would also support batch-skipping logic like in the Hugging Face training script.