Failed to Resume Training w/ CombinedStreamingDataset
🐛 Bug
My training run crashed, so I tried to resume it from the previous PyTorch Lightning checkpoint.
When I do so, I get the following error:
[rank5]: Original Traceback (most recent call last):
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank5]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank5]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank5]: self.dataset_iter = iter(dataset)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 160, in __iter__
[rank5]: self._iterator = _CombinedDatasetIterator(
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in __init__
[rank5]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in <listcomp>
[rank5]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
[rank5]: self._validate_state_dict()
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
[rank5]: raise ValueError(
[rank5]: ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `46320` instead of `44976`
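To make the failing condition concrete: judging from the message, resuming rejects a restored `num_samples_yielded` counter that is larger than the dataset length it is checked against. The sketch below mirrors that kind of check using the numbers from the rank5 traceback. `validate_num_samples_yielded` is a hypothetical name, and this is not LitData's actual implementation.

```python
# Hypothetical sketch of the kind of check behind the ValueError above.
# Not LitData's code: the function name and signature are illustrative only.


def validate_num_samples_yielded(num_samples_yielded: int, dataset_length: int) -> None:
    """Raise if a restored counter claims more samples than the dataset holds."""
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the dataset length. "
            f"Found `{num_samples_yielded}` instead of `{dataset_length}`"
        )


# Values copied from the rank5 traceback.
restored, length = 46320, 44976

try:
    validate_num_samples_yielded(restored, length)
except ValueError as err:
    print(err)

print("overshoot:", restored - length)  # overshoot: 1344
```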
To Reproduce
I'm not sure what a minimal example of this bug would be.
Code sample
My dataset is initialized as follows:
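from litdata import CombinedStreamingDataset

# train_datasets: a list of StreamingDataset instances (defined elsewhere)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, iterate_over_all=True)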
I do not weight the individual datasets.
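For context on `iterate_over_all=True`: the combined dataset iterates over all of its sub-datasets instead of stopping when the first one is exhausted, and, as the traceback shows, each sub-dataset validates its own restored state when it is re-iterated. The toy model below illustrates the invariant that check relies on: each sub-dataset's `num_samples_yielded` never exceeds its length. `ToyCombined` is a hypothetical stand-in, and picking uniformly among non-exhausted sub-datasets is an assumption, not LitData's sampling rule.

```python
import random
from typing import Dict, Iterator, List


class ToyCombined:
    """Hypothetical toy model of a combined dataset with iterate_over_all=True.

    Not LitData's implementation: sampling picks uniformly among sub-datasets
    that still have samples, and the state is one counter per sub-dataset.
    """

    def __init__(self, datasets: List[List[str]], seed: int = 0) -> None:
        self.datasets = datasets
        self.rng = random.Random(seed)
        self.num_samples_yielded = [0] * len(datasets)  # one counter per sub-dataset

    def __iter__(self) -> Iterator[str]:
        while True:
            # Only sub-datasets with unread samples remain eligible.
            active = [i for i, ds in enumerate(self.datasets) if self.num_samples_yielded[i] < len(ds)]
            if not active:
                return  # every sub-dataset is exhausted: end of the pass
            i = self.rng.choice(active)
            sample = self.datasets[i][self.num_samples_yielded[i]]
            self.num_samples_yielded[i] += 1
            yield sample

    def state_dict(self) -> Dict[str, List[int]]:
        return {"num_samples_yielded": list(self.num_samples_yielded)}


combined = ToyCombined([["a0", "a1", "a2"], ["b0", "b1", "b2", "b3", "b4"]])
samples = list(combined)
print(len(samples))           # 8
print(combined.state_dict())  # {'num_samples_yielded': [3, 5]}
```

The toy does not model epoch boundaries, distributed ranks, or DataLoader workers; it only shows the per-sub-dataset counters that a checkpoint would need to restore.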
Expected behavior
Training resumes from the checkpoint without errors.
Environment
- PyTorch Version (e.g., 1.0): 2.3.1
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source): N/A
- Python version: 3.10.12
- CUDA/cuDNN version: 12.1
- GPU models and configuration: 2 nodes, 8 H100s per node
- Any other relevant information:
Additional context
The error above occurs with the latest version of LitData. Previously, we were training with an older version of LitData, and the following error occurred instead:
[rank15]: Traceback (most recent call last):
[rank15]: File "/home/sahil/project/train.py", line 89, in <module>
[rank15]: main(config)
[rank15]: File "/home/sahil/project/train.py", line 69, in main
[rank15]: trainer.fit(model, datamodule=my_data_module, ckpt_path=trainer_checkpoint_path)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
[rank15]: call._call_and_handle_interrupt(
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank15]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank15]: return function(*args, **kwargs)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
[rank15]: self._run(model, ckpt_path=ckpt_path)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
[rank15]: results = self._run_stage()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
[rank15]: self.fit_loop.run()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank15]: self.advance()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank15]: self.epoch_loop.run(self._data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank15]: self.advance(data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 212, in advance
[rank15]: batch, _, __ = next(data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank15]: batch = super().__next__()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank15]: batch = next(self.iterator)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank15]: out = next(self._iterator)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank15]: out[i] = next(self.iterators[i])
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 631, in __iter__
[rank15]: for batch in super().__iter__():
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank15]: data = self._next_data()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
[rank15]: return self._process_data(data)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank15]: data.reraise()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank15]: raise exception
[rank15]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank15]: Original Traceback (most recent call last):
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank15]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank15]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank15]: self.dataset_iter = iter(dataset)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank15]: self._iterator = _CombinedDatasetIterator(
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank15]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank15]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 236, in __iter__
[rank15]: self._resume(workers_chunks, workers_intervals)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 308, in _resume
[rank15]: interval = self.worker_intervals[self.chunk_index]
[rank15]: IndexError: list index out of range
Also note that this error occurs several epochs into training, with data stored locally (not streamed from an S3 blob store).
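The older traceback fails inside `_resume`, at `interval = self.worker_intervals[self.chunk_index]`. One plausible reading, not confirmed in this thread, is that both tracebacks stem from the same state: a restored counter larger than the samples actually assigned, which the older version hit as an out-of-range list index and the newer version rejects up front. The sketch below shows how mapping a sample count onto a worker's chunk intervals runs past the end of the list in that case. `find_resume_position` and the `(start, end)` interval layout are hypothetical, not LitData's code.

```python
from typing import List, Tuple


def find_resume_position(intervals: List[Tuple[int, int]], num_samples_yielded: int) -> Tuple[int, int]:
    """Map a restored sample count to (chunk_index, offset_in_chunk).

    Hypothetical sketch, not LitData's code. It deliberately has no bounds check:
    when the count covers every interval, the lookup runs one past the end,
    the same failure shape as `self.worker_intervals[self.chunk_index]` above.
    """
    chunk_index, remaining = 0, num_samples_yielded
    while True:
        start, end = intervals[chunk_index]  # IndexError once chunk_index == len(intervals)
        size = end - start
        if remaining < size:
            return chunk_index, remaining
        remaining -= size  # skip this whole chunk
        chunk_index += 1


# Hypothetical worker owning 300 samples split across three chunks.
intervals = [(0, 100), (100, 250), (250, 300)]
print(find_resume_position(intervals, 120))  # (1, 20)

try:
    find_resume_position(intervals, 320)  # restored count exceeds the 300 samples
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range
```

A length check like the one in the newer version would surface this condition as an explicit `ValueError` before any indexing happens.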
Hi @schopra8, thank you for bringing this issue to our attention!
Also, I have a draft PR (#362) that addresses the same issues (#331, #363). Initially, I was working on solving the issue for both cases (with and without weights), but I'll focus on the case without weights first to speed things up.
In the meantime, please feel free to take a look at the PR and drop any feedback or suggestions. Always open to any thoughts you might have!
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Closing this as it has been fixed by #507. Please feel free to reopen the issue if it still persists.