litdata
Failed to Resume Training w/ CombinedStreamingDataset
🐛 Bug
My training run crashed, so I tried to resume it from the previous PyTorch Lightning checkpoint.
When I do, I get the following error --
[rank5]: Original Traceback (most recent call last):
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank5]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank5]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank5]: self.dataset_iter = iter(dataset)
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 160, in __iter__
[rank5]: self._iterator = _CombinedDatasetIterator(
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in __init__
[rank5]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 208, in <listcomp>
[rank5]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 223, in __iter__
[rank5]: self._validate_state_dict()
[rank5]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 479, in _validate_state_dict
[rank5]: raise ValueError(
[rank5]: ValueError: The provided `num_samples_yielded` state is greater than the dataset length. Found `46320` instead of `44976`
To Reproduce
I'm not sure what a minimal example of this bug would be.
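For context, here is a minimal sketch of the kind of check that fails. This is an illustrative simulation modelled loosely on litdata's `_validate_state_dict`, not litdata's actual code; the function name `validate_resume_state` is made up. The point is that a restored progress counter larger than the current per-rank dataset length triggers exactly this ValueError.

```python
# Illustrative simulation of the failing check; NOT litdata's real code.
# On resume, litdata's StreamingDataset compares the restored sample
# counter from the checkpoint against the current dataset length.

def validate_resume_state(num_samples_yielded: int, dataset_length: int) -> None:
    """Raise if the restored progress counter exceeds the dataset length."""
    if num_samples_yielded > dataset_length:
        raise ValueError(
            "The provided `num_samples_yielded` state is greater than the "
            f"dataset length. Found `{num_samples_yielded}` instead of "
            f"`{dataset_length}`"
        )

# The numbers from the traceback above: the checkpoint recorded 46320
# samples yielded, but the dataset now reports a length of only 44976.
try:
    validate_resume_state(46320, 44976)
except ValueError as exc:
    print(exc)
```

If the dataset length seen on resume differs from the one at checkpoint time (e.g., because sharding across ranks or workers changed), this check will fire even though the checkpoint itself is valid.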
Code sample
My dataset is initialized as --
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, iterate_over_all=True)
I do not weight the individual datasets.
Expected behavior
Training should resume from the checkpoint without errors.
Environment
- PyTorch Version (e.g., 1.0): 2.3.1
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source): NA
- Python version: 3.10.12
- CUDA/cuDNN version: 12.1
- GPU models and configuration: 2 nodes, 8 H100 GPUs per node
- Any other relevant information:
Additional context
The bug shown above emerges when I run on the latest version of LitData. Previously, we were training on an older version of LitData, and the following error cropped up instead --
[rank15]: Traceback (most recent call last):
[rank15]: File "/home/sahil/project/train.py", line 89, in <module>
[rank15]: main(config)
[rank15]: File "/home/sahil/project/train.py", line 69, in main
[rank15]: trainer.fit(model, datamodule=my_data_module, ckpt_path=trainer_checkpoint_path)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
[rank15]: call._call_and_handle_interrupt(
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank15]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank15]: return function(*args, **kwargs)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
[rank15]: self._run(model, ckpt_path=ckpt_path)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
[rank15]: results = self._run_stage()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
[rank15]: self.fit_loop.run()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank15]: self.advance()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank15]: self.epoch_loop.run(self._data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank15]: self.advance(data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 212, in advance
[rank15]: batch, _, __ = next(data_fetcher)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank15]: batch = super().__next__()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank15]: batch = next(self.iterator)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank15]: out = next(self._iterator)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank15]: out[i] = next(self.iterators[i])
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 631, in __iter__
[rank15]: for batch in super().__iter__():
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank15]: data = self._next_data()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
[rank15]: return self._process_data(data)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank15]: data.reraise()
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank15]: raise exception
[rank15]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank15]: Original Traceback (most recent call last):
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank15]: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank15]: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank15]: self.dataset_iter = iter(dataset)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank15]: self._iterator = _CombinedDatasetIterator(
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank15]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank15]: self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 236, in __iter__
[rank15]: self._resume(workers_chunks, workers_intervals)
[rank15]: File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 308, in _resume
[rank15]: interval = self.worker_intervals[self.chunk_index]
[rank15]: IndexError: list index out of range
Also note that this error occurred several epochs into training, with data stored locally (not streamed from an S3 blob store).
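For what it's worth, the older IndexError comes from `worker_intervals[self.chunk_index]` in litdata's `_resume`. The sketch below is a hypothetical simulation of that lookup, not litdata's real code; my guess (an assumption, not confirmed) is that the chunk index restored from the checkpoint can exceed the number of chunk intervals assigned to the current worker if the assignment changed between runs.

```python
# Hypothetical sketch of the older failure mode; NOT litdata's real code.
# On resume, a chunk index restored from the checkpoint indexes the list
# of chunk intervals assigned to the current worker. If that assignment
# shrank (e.g., a different world size or worker count), the restored
# index can fall outside the new, shorter list.

def resume_interval(worker_intervals: list, restored_chunk_index: int):
    # This lookup mirrors litdata/streaming/dataset.py line 308 above.
    return worker_intervals[restored_chunk_index]

# Only two chunks are assigned to this worker now, but the checkpoint
# says iteration had reached chunk index 5.
intervals = [(0, 1000), (1000, 2000)]
try:
    resume_interval(intervals, 5)
except IndexError as exc:
    print(type(exc).__name__)
```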