litdata
Resuming StreamingDataLoader with num_workers=0 fails
Bug description
Using a StreamingDataLoader with num_workers=0 works, but resuming its state does not: the state validation in `_validate_state_dict` raises an explicit error (see the traceback below).
Using num_workers=0 is arguably not very meaningful for real applications, but it is useful for debugging and testing. Alternatively, if supporting it is difficult, StreamingDataLoader could simply enforce num_workers>=1. Either way, I think we should do something about it: 0 is the default for the dataloader, so users may forget to set it and then run into this error, which could confuse them. One way to support num_workers=0 is sketched below.
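For illustration, here is a minimal sketch of the "support num_workers=0" option. This is not litdata's actual implementation, and the helper names are made up; the only point is that a DataLoader with num_workers=0 still iterates the dataset in exactly one process, so the stored state could be compared against max(num_workers, 1):

def _effective_num_workers(num_workers: int) -> int:
    # num_workers=0 means "iterate in the main process", which is equivalent to
    # a single worker as far as the resumption state is concerned.
    return max(num_workers, 1)


def _check_num_workers(state_num_workers: int, current_num_workers: int) -> None:
    # Hypothetical replacement for the strict equality check in _validate_state_dict.
    if _effective_num_workers(state_num_workers) != _effective_num_workers(current_num_workers):
        raise ValueError(
            "The provided `num_workers` state doesn't match the current one. "
            f"Found `{state_num_workers}` instead of `{current_num_workers}`."
        )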
What version are you seeing the problem on?
master
How to reproduce the bug
import torch


def run():
    checkpoint_path = "checkpoint.pt"

    # Save a checkpoint
    train_dataloader = create_dataloader()
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)
    torch.save(train_dataloader.state_dict(), checkpoint_path)

    # Reset and attempt resume
    train_dataloader = create_dataloader()
    state = {"train_dataloader": train_dataloader}
    train_dataloader.load_state_dict(torch.load(checkpoint_path))
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)


def create_dataloader():
    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
    from lightning.data.streaming.item_loader import TokensLoader

    train_datasets = [
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/slimpajama/train",
            item_loader=TokensLoader(block_size=4),
        ),
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=4),
        ),
    ]
    combined_dataset = CombinedStreamingDataset(datasets=train_datasets)
    train_dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=0)  # <--- BUG WHEN num_workers=0
    return train_dataloader


if __name__ == "__main__":
    run()
Error messages and logs
Traceback (most recent call last):
File "/teamspace/studios/this_studio/repro_worker.py", line 50, in <module>
run()
File "/teamspace/studios/this_studio/repro_worker.py", line 25, in run
next(train_iterator)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 432, in __iter__
for batch in super().__iter__():
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
return self._get_iterator()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 504, in _get_iterator
return _SingleProcessDataLoaderIter(self)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
self._dataset_fetcher = _DatasetKind.create_fetcher(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 83, in __iter__
self._iterator = _CombinedDatasetIterator(
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in __init__
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in <listcomp>
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 146, in __iter__
self._validate_state_dict()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 328, in _validate_state_dict
raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.
Environment
PyTorch Lightning Version: master (2.2dev)
Moved from https://github.com/Lightning-AI/pytorch-lightning/issues/19335, submitted by @awaelchli
Following up on this, what is the recommended practice/solution? I was able to load my checkpoint, manually adjust it from num_workers=0 to 1, and save it again so that it passes the check when loading the state_dict, but I wanted to know if there's a better workaround.
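For reference, here is roughly what that manual patch looks like. This is only a sketch: the exact layout of the saved state dict depends on the litdata version, so the helper below (patch_num_workers, an illustrative name) walks the structure generically instead of assuming specific keys.

import torch


def patch_num_workers(obj, old=0, new=1):
    """Recursively replace `num_workers: old` with `new` in nested dicts/lists."""
    if isinstance(obj, dict):
        return {
            k: (new if k == "num_workers" and v == old else patch_num_workers(v, old, new))
            for k, v in obj.items()
        }
    if isinstance(obj, list):
        return [patch_num_workers(v, old, new) for v in obj]
    return obj


state = torch.load("checkpoint.pt")
torch.save(patch_num_workers(state), "checkpoint_patched.pt")

Note that this only sidesteps the validation; the restored sample position then assumes a single worker, so it is a stopgap rather than a supported resumption path.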
Hi, I've also encountered this issue, but with non-zero num_workers. Even more bizarre, I hit it 75% of the way through a validation epoch when training with the latest PyTorch Lightning Trainer, and after successfully getting through the first validation epoch and even saving a .ckpt!
Any ideas/updates on what's going on here?
Traceback (most recent call last):
File "/valohai/repository/ml/polle/train_polle.py", line 184, in train
trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=polle_path)
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
self.fit_loop.run()
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
self.advance()
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 212, in advance
batch, _, __ = next(data_fetcher)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
batch = super().__next__()
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
batch = next(self.iterator)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
out = next(self._iterator)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/combined_loader.py", line 78, in __next__
out[i] = next(self.iterators[i])
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataloader.py", line 620, in __iter__
for batch in super().__iter__():
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
return self._process_data(data)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
data.reraise()
File "/usr/local/lib/python3.11/dist-packages/torch/_utils.py", line 705, in reraise
raise exception
ValueError: Caught ValueError in DataLoader worker process 16.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataset.py", line 199, in __iter__
self._validate_state_dict()
File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataset.py", line 388, in _validate_state_dict
raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `24` instead of `16`.
Hey @ukasschmit, LitData doesn't support changing the number of workers when resuming. Can you restart with 16 workers?
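In practice that means creating the resumed dataloader with the same num_workers as the run that saved the checkpoint, before loading its state. A rough sketch, assuming litdata's top-level imports and with the input_dir, paths, and worker count purely illustrative:

import torch
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(input_dir="s3://my-bucket/my-dataset")  # illustrative
# num_workers must match the value used in the run that produced the checkpoint (16 here).
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=16)
dataloader.load_state_dict(torch.load("checkpoint.pt"))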