
Failed to Resume Training w/ CombinedStreamingDataset

Open • schopra8 opened this issue 5 months ago • 1 comment

🐛 Bug

My training run crashed, so I tried to resume it from the previous PyTorch Lightning checkpoint.

When I do, I get the following error:

[rank5]: Original Traceback (most recent call last):
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank5]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank5]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank5]:     self.dataset_iter = iter(dataset)
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 160, in __iter__
[rank5]:     self._iterator = _CombinedDatasetIterator(
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in __init__
[rank5]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in <listcomp>
[rank5]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
[rank5]:     self._validate_state_dict()
[rank5]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
[rank5]:     raise ValueError(
[rank5]: ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `46320` instead of `44976`
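The failing check reduces to a single comparison between the saved counter and the dataset length. The sketch below is a hypothetical reconstruction inferred only from the error message above. It is not LitData's `_validate_state_dict`, and the name `validate_num_samples_yielded` is made up for illustration. Plugging in the reported values shows that the saved state overshoots the dataset by 1,344 samples.

```python
def validate_num_samples_yielded(num_samples_yielded: int, dataset_length: int) -> None:
    """Hypothetical resume-time check modeled on the error message above."""
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the dataset length. "
            f"Found `{num_samples_yielded}` instead of `{dataset_length}`"
        )


# Values taken from the rank-5 traceback above.
saved_count, dataset_length = 46320, 44976
overshoot = saved_count - dataset_length  # 1344 samples past the end

try:
    validate_num_samples_yielded(saved_count, dataset_length)
except ValueError as err:
    print(err)
```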

To Reproduce

I'm not sure what a minimal example of this bug would be.

Code sample

My dataset is initialized as follows:

combined_dataset = CombinedStreamingDataset(datasets=train_datasets, iterate_over_all=True)

I do not weight the individual datasets.
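To show what `iterate_over_all=True` should imply for per-dataset counters, here is a toy model. It assumes that unweighted datasets are each drained to completion within an epoch. Round-robin selection is a simplifying stand-in, since LitData's actual ordering across datasets may differ. The names `iterate_over_all_toy` and `count_yielded` are hypothetical. In this model, each per-dataset counter ends an epoch exactly at its dataset's length and never exceeds it.

```python
from typing import Iterator, List, Sequence, Tuple


def iterate_over_all_toy(datasets: Sequence[Sequence[int]]) -> Iterator[Tuple[int, int]]:
    """Toy combined iterator: round-robin until every dataset is exhausted.

    Yields (dataset_index, sample). Round-robin is a simplifying assumption;
    it is not LitData's selection order.
    """
    iterators = [iter(ds) for ds in datasets]
    active = list(range(len(datasets)))
    while active:
        for idx in list(active):  # iterate over a copy so removal is safe
            try:
                sample = next(iterators[idx])
            except StopIteration:
                active.remove(idx)  # this dataset is drained for the epoch
                continue
            yield idx, sample


def count_yielded(datasets: Sequence[Sequence[int]]) -> List[int]:
    """Per-dataset count of samples yielded over one full epoch."""
    counts = [0] * len(datasets)
    for idx, _ in iterate_over_all_toy(datasets):
        counts[idx] += 1
    return counts


toy_datasets = [list(range(5)), list(range(3))]
print(count_yielded(toy_datasets))  # [5, 3]
```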

Expected behavior

Training resumes from the checkpoint without errors.
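In this context, "resume" means the round trip below. A mid-epoch counter is saved, restored into a fresh dataset object, and iteration continues from the next sample. `ToyResumableDataset` is a hypothetical stand-in written for illustration. It is not LitData's `StreamingDataset` and does not model workers, ranks, or chunks.

```python
from typing import Dict, Iterator, List, Sequence


class ToyResumableDataset:
    """Hypothetical stand-in for a resumable dataset (no workers, ranks, or chunks)."""

    def __init__(self, samples: Sequence[int]) -> None:
        self.samples: List[int] = list(samples)
        self.num_samples_yielded = 0

    def state_dict(self) -> Dict[str, int]:
        return {"num_samples_yielded": self.num_samples_yielded}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        restored = state["num_samples_yielded"]
        # Analogous to the check that fails in the traceback above.
        if restored > len(self.samples):
            raise ValueError(f"num_samples_yielded={restored} exceeds length {len(self.samples)}")
        self.num_samples_yielded = restored

    def __iter__(self) -> Iterator[int]:
        # Continue from wherever the counter points (0 for a fresh epoch).
        while self.num_samples_yielded < len(self.samples):
            sample = self.samples[self.num_samples_yielded]
            self.num_samples_yielded += 1
            yield sample
        self.num_samples_yielded = 0  # epoch finished: reset for the next one


# Consume 4 of 10 samples, checkpoint, then resume in a fresh object.
original = ToyResumableDataset(range(10))
it = iter(original)
seen = [next(it) for _ in range(4)]
resumed = ToyResumableDataset(range(10))
resumed.load_state_dict(original.state_dict())
rest = list(resumed)
print(seen + rest == list(range(10)))  # True
```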

Environment

  • PyTorch Version (e.g., 1.0): 2.3.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.10.12
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2 nodes, 8 H-100s per node
  • Any other relevant information:
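To pin down exact package versions, including LitData and PyTorch Lightning, you can run a small standard-library helper like the one below in the training environment. It is a hypothetical sketch and is not part of either project. It relies only on `importlib.metadata.version` and the `platform` module.

```python
import platform
from importlib import metadata
from typing import Dict, Sequence

DEFAULT_PACKAGES = ("torch", "litdata", "pytorch-lightning", "lightning")


def package_version(name: str) -> str:
    """Installed version of a distribution, or a placeholder if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report(packages: Sequence[str] = DEFAULT_PACKAGES) -> Dict[str, str]:
    """Python/OS info plus the installed version of each requested package."""
    report = {"python": platform.python_version(), "os": platform.system()}
    for name in packages:
        report[name] = package_version(name)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```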

Additional context

The error above occurs with the latest version of LitData. We were previously training on an older version of LitData, and there the following error appeared instead:

[rank15]: Traceback (most recent call last):
[rank15]:  File "/home/sahil/project/train.py", line 89, in <module>
[rank15]:   main(config)
[rank15]:  File "/home/sahil/project/train.py", line 69, in main
[rank15]:   trainer.fit(model, datamodule=my_data_module, ckpt_path=trainer_checkpoint_path)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
[rank15]:   call._call_and_handle_interrupt(
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank15]:   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank15]:   return function(*args, **kwargs)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
[rank15]:   self._run(model, ckpt_path=ckpt_path)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
[rank15]:   results = self._run_stage()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
[rank15]:   self.fit_loop.run()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank15]:   self.advance()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank15]:   self.epoch_loop.run(self._data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank15]:   self.advance(data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 212, in advance
[rank15]:   batch, _, __ = next(data_fetcher)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank15]:   batch = super().__next__()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank15]:   batch = next(self.iterator)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank15]:   out = next(self._iterator)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank15]:   out[i] = next(self.iterators[i])
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 631, in __iter__
[rank15]:   for batch in super().__iter__():
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank15]:   data = self._next_data()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
[rank15]:   return self._process_data(data)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank15]:   data.reraise()
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank15]:   raise exception
[rank15]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank15]: Original Traceback (most recent call last):
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank15]:   fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank15]:   return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank15]:   self.dataset_iter = iter(dataset)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank15]:   self._iterator = _CombinedDatasetIterator(
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank15]:   self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank15]:   self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 236, in __iter__
[rank15]:   self._resume(workers_chunks, workers_intervals)
[rank15]:  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 308, in _resume
[rank15]:   interval = self.worker_intervals[self.chunk_index]
[rank15]: IndexError: list index out of range
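The older failure is an out-of-range list lookup: the restored `chunk_index` points past the end of that worker's `worker_intervals`. The sketch below uses a hypothetical helper, `resume_interval`, which is not LitData's `_resume`. It reproduces the failing lookup pattern and shows a guarded variant that reports the mismatch explicitly instead of raising a bare `IndexError`.

```python
from typing import List, Tuple

# (start, end) sample range of one chunk assigned to a worker; illustrative only.
Interval = Tuple[int, int]


def resume_interval(worker_intervals: List[Interval], chunk_index: int) -> Interval:
    """Return the interval to resume from, validating the restored index.

    Hypothetical helper: the unguarded form `worker_intervals[chunk_index]`
    raises IndexError when `chunk_index` is past the end of the list.
    """
    if not 0 <= chunk_index < len(worker_intervals):
        raise ValueError(
            f"Restored chunk_index {chunk_index} is out of range for the "
            f"{len(worker_intervals)} interval(s) assigned to this worker"
        )
    return worker_intervals[chunk_index]


intervals = [(0, 100), (100, 200)]
print(resume_interval(intervals, 1))  # (100, 200)
```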

Note also that this error occurs several epochs into training, with data stored locally (not streamed from an S3 bucket).

schopra8, Sep 05 '24 02:09