
Save and resume the state of a DataLoader

Open lhoestq opened this issue 2 years ago • 18 comments

It would be nice, when using datasets with a PyTorch DataLoader, to be able to resume training from a DataLoader state (e.g. to resume a run that crashed).

What I have in mind (but lmk if you have other ideas or comments):

For map-style datasets, this requires a PyTorch Sampler whose state can be saved and reloaded per node and worker; a rough sketch follows below.
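For illustration, a minimal sketch of what such a resumable sampler could look like (the class name and its fields are hypothetical, not an existing datasets or PyTorch API):

import torch
from torch.utils.data import Sampler

class ResumableRandomSampler(Sampler):
    """Illustrative sketch: a random sampler whose position can be saved and restored."""

    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        self.start_index = 0  # number of indices already consumed in this epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # same permutation for a given epoch
        indices = torch.randperm(len(self.data_source), generator=g).tolist()
        for i, index in enumerate(indices):
            if i < self.start_index:
                continue  # fast-forward to the saved position
            self.start_index = i + 1
            yield index
        self.start_index = 0  # epoch finished, reset for the next one
        self.epoch += 1

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {"epoch": self.epoch, "start_index": self.start_index, "seed": self.seed}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        self.start_index = state_dict["start_index"]
        self.seed = state_dict["seed"]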

For iterable datasets, this requires saving the state of the dataset iterator (see the sketch after this list), which includes:

  • the current shard idx and row position in the current shard
  • the epoch number
  • the rng state
  • the shuffle buffer
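For illustration only, that iterator state could be captured in a structure along these lines (field names and values are hypothetical):

import numpy as np

rng = np.random.default_rng(42)

# Hypothetical shape of the per-worker state to checkpoint (illustrative only)
iterator_state = {
    "shard_idx": 3,                        # which shard is currently being read
    "shard_row": 12345,                    # row position inside that shard
    "epoch": 1,                            # epoch number (drives per-epoch shard shuffling)
    "rng_state": rng.bit_generator.state,  # RNG state used for shuffling
    "shuffle_buffer": [],                  # contents of the approximate shuffle buffer
}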

Right now you can already resume data loading of an iterable dataset by using IterableDataset.skip, but it takes a lot of time because it re-iterates over all the past data until it reaches the resuming point.
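For reference, that workaround currently looks roughly like this (the number of examples already seen has to be tracked and checkpointed by the training loop itself):

from datasets import load_dataset

ds = load_dataset(..., streaming=True)

num_examples_seen = 1_000_000  # tracked by the training loop and saved in the checkpoint
# Correct but slow: skip() still streams through every skipped example
resumed_ds = ds.skip(num_examples_seen)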

cc @stas00 @sgugger

lhoestq avatar Jan 23 '23 10:01 lhoestq

Something that'd be nice to have is a "manual update of state". One of the learnings from training LLMs is that the ability to skip some batches whenever we notice a huge spike might be handy.

thomasw21 avatar Jan 23 '23 11:01 thomasw21

Your outline spec is very sound and clear, @lhoestq - thank you!

@thomasw21, indeed that would be a wonderful extra feature. In Megatron-DeepSpeed we manually drained the dataloader for the range we wanted to skip. I wasn't very satisfied with the way we did it, since its behavior would change if you were to do multiple range skips. I think it should remember all the ranges it skipped and not just skip the last range, since otherwise the data is inconsistent (but we should probably discuss this in a separate issue so as not to derail this much bigger one).
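To make the multi-range point concrete, here is a purely illustrative sketch (not the Megatron-DeepSpeed implementation) that records every skipped range and drains all of them on replay:

def should_skip(step, skipped_ranges):
    """True if this step falls inside any previously recorded skip range."""
    return any(start <= step < end for start, end in skipped_ranges)

# e.g. iteration ranges around two loss spikes; keep all of them, not just the latest
skipped_ranges = [(1000, 1100), (5000, 5200)]

for step in range(6000):  # stand-in for enumerate(dataloader)
    if should_skip(step, skipped_ranges):
        continue  # drain this batch without training on it
    ...  # normal training step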

stas00 avatar Jan 24 '23 01:01 stas00

Hi there! I think this is a critical issue, and I have an urgent need for it in my attempt to train on a super large-scale dataset using datasets. It is impractical to resume a long-running (e.g. month-long) experiment by re-iterating over all previously seen data, which could take several days.

@stas00 @thomasw21 @lhoestq Any updates on this problem now that a year has passed?

yqy2001 avatar Jan 25 '24 09:01 yqy2001

any update?

dancingpipi avatar Feb 01 '24 10:02 dancingpipi

No update so far. I wonder if someone has implemented a resumable PyTorch Sampler somewhere.

Regarding resuming a streaming dataset, we'd first like to have an efficient way to skip shards automatically, but this is not implemented yet.

lhoestq avatar Feb 02 '24 09:02 lhoestq

I opened a draft here for IterableDataset: https://github.com/huggingface/datasets/pull/6658

"""Requires https://github.com/huggingface/datasets/pull/6658 (WIP)"""
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(..., streaming=True)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42, buffer_size=1000)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = ds.state_dict()

# Resumable training loop
ds.load_state_dict(dataset_state_dict)
dataloader = DataLoader(ds, batch_size=batch_size)
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        dataset_state_dict = ds.state_dict()
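In practice, dataset_state_dict would be persisted together with the rest of the checkpoint; torch.save is just one way to serialize it:

import torch

# Save alongside the model/optimizer state
torch.save({"dataset": dataset_state_dict, "step": step}, "checkpoint.pt")

# On restart, restore it before building the DataLoader
checkpoint = torch.load("checkpoint.pt")
ds.load_state_dict(checkpoint["dataset"])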

lhoestq avatar Feb 19 '24 15:02 lhoestq

Hi @lhoestq - can you provide more information on how to save and restore vanilla DataLoader state with map-style datasets?

jwliu36 avatar Feb 21 '24 08:02 jwliu36

For now, the easiest approach is probably to use the vanilla DataLoader only for batching and multiprocessing, and to implement the resuming logic using a Dataset (it has .select() to skip examples) and a dataset_state_dict:

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(...)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42)

batch_size = 32   # example values
save_steps = 100

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = {"step": 0}

# Resumable training loop
start_step = dataset_state_dict["step"]
dataloader = DataLoader(ds.select(range(start_step * batch_size, len(ds))), batch_size=batch_size)
for step, batch in enumerate(dataloader, start=start_step):
    ...
    if step % save_steps == 0:
        # save the next step to run, so the batch at `step` is not repeated on resume
        dataset_state_dict = {"step": step + 1}
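Note that this relies on the shuffle being seeded (seed=42 above) so that select() resumes on exactly the same ordering after a restart, and it only tracks progress at batch granularity; anything finer-grained would need to be saved separately.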

lhoestq avatar Feb 21 '24 11:02 lhoestq

Hello, I found a similar implementation online that seems to solve this problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py#L93. It looks like we can call set_start_iter on StatefulDistributedSampler to implement the stateful resume behavior we want.
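Roughly, using that sampler would look like this (the constructor and set_start_iter signatures are taken from the linked file, so treat them as an assumption; ds, batch_size and start_iter are as in the earlier examples):

from torch.utils.data import DataLoader
from vissl.data.data_helper import StatefulDistributedSampler

sampler = StatefulDistributedSampler(ds, batch_size=batch_size)
sampler.set_start_iter(start_iter)  # iteration to resume from, restored from a checkpoint
dataloader = DataLoader(ds, batch_size=batch_size, sampler=sampler)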

xgbj avatar Mar 19 '24 02:03 xgbj

Hi y'all, @lhoestq I wanted to flag that we currently have a StatefulDataLoader in pytorch/data/torchdata that has state_dict/load_state_dict methods, which will call a dataset's state_dict/load_state_dict methods but also handle multiprocessing under the hood. Any chance we can collaborate on this and try to get them to work well together? Please have a look here for some basic examples: https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader#saving-and-loading-state
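To illustrate the delegation, here is a toy single-process sketch based on the linked examples (the behavioral details are an assumption, not production code):

from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

class CountingDataset(IterableDataset):
    """Toy iterable dataset whose read position can be checkpointed."""

    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        while self.i < self.n:
            yield self.i
            self.i += 1

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]

dl = StatefulDataLoader(CountingDataset(100), batch_size=4, num_workers=0)
it = iter(dl)
next(it)
next(it)                         # consume two batches
checkpoint = dl.state_dict()     # calls CountingDataset.state_dict() under the hood

dl2 = StatefulDataLoader(CountingDataset(100), batch_size=4, num_workers=0)
dl2.load_state_dict(checkpoint)  # restores the dataset position before iterating
next(iter(dl2))                  # continues from the saved position, not from 0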

andrewkho avatar Apr 29 '24 22:04 andrewkho

Fantastic! This will help push our IterableDataset state_dict implementation at https://github.com/huggingface/datasets/pull/6658 :) I'll check if anything is missing to make them work together, and add tests and some docs referring to the StatefulDataLoader :)

lhoestq avatar Apr 30 '24 09:04 lhoestq

Ah, I just saw this disclaimer in the torchdata README, and it feels like people should not rely on it. Should the StatefulDataLoader live elsewhere, @andrewkho?

⚠️ As of July 2023, we have paused active development on TorchData and have paused new releases. We have learnt a lot from building it and hearing from users, but also believe we need to re-evaluate the technical design and approach given how much the industry has changed since we began the project. During the rest of 2023 we will be re-evaluating our plans in this space. Please reach out if you have suggestions or comments (please use https://github.com/pytorch/data/issues/1196 for feedback).

lhoestq avatar Apr 30 '24 12:04 lhoestq

@lhoestq Good find, we are in the midst of updating this disclaimer as we're re-starting development and regular releases, though our approach will be to iterate on DL V1 (i.e. StatefulDataLoader) instead of continuing development on datapipes + DL V2. Let's discuss on a call at some point to figure out the best path forward!

andrewkho avatar Apr 30 '24 18:04 andrewkho

As a heads up, IterableDataset state_dict has been added in https://github.com/huggingface/datasets/pull/6658

...and it works out of the box with the torchdata StatefulDataLoader :)

See the docs at https://huggingface.co/docs/datasets/main/en/use_with_pytorch#checkpoint-and-resume
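Following those docs, the combination looks roughly like this (simplified here):

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

ds = load_dataset(..., streaming=True)
dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)

# somewhere in the training loop: checkpoint the dataloader (and dataset) state
state_dict = dataloader.state_dict()

# on restart: rebuild the dataloader and restore the state before iterating
dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)
dataloader.load_state_dict(state_dict)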

lhoestq avatar Jul 22 '24 11:07 lhoestq

amazing! Thank you, @lhoestq

Does it work with non-iterable datasets as well? The docs only mention iterable datasets.

stas00 avatar Jul 24 '24 20:07 stas00

It's for iterable datasets only. For regular datasets, I believe the sampler should implement state_dict, but @andrewkho might know best how to resume a regular dataset with torchdata.

lhoestq avatar Jul 25 '24 12:07 lhoestq

@stas00 StatefulDataLoader will save and resume samplers for map-style datasets. If the sampler does not provide state_dict/load_state_dict, it will naively skip samples to fast-forward. See here for more details: https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/README.md

Hope this helps!

andrewkho avatar Jul 25 '24 15:07 andrewkho

Thank you very much for clarifying that, Andrew.

stas00 avatar Jul 25 '24 17:07 stas00