[Resumable IterableDataset] Add IterableDataset state_dict
A simple implementation of a mechanism to resume an IterableDataset. This is WIP and untested.
Example:
```python
from datasets import Dataset, concatenate_datasets

ds = Dataset.from_dict({"a": range(5)}).to_iterable_dataset(num_shards=3)
ds = concatenate_datasets([ds] * 2)
print(f"{ds.state_dict()=}")
for i, example in enumerate(ds):
    print(example)
    if i == 6:
        state_dict = ds.state_dict()
ds.load_state_dict(state_dict)
print(f"{ds.state_dict()=}")
for example in ds:
    print(example)
```
which prints:
```
ds.state_dict()={'ex_iterable_idx': 0, 'ex_iterables': [{'shard_idx': 0, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 0}]}
{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
ds.state_dict()={'ex_iterable_idx': 1, 'ex_iterables': [{'shard_idx': 3, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 2}]}
{'a': 2}
{'a': 3}
{'a': 4}
```
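To make the nested structure of that state more concrete, here is a toy, pure-Python sketch of how a concatenation can nest its children's states and resume from them. It is not the `datasets` implementation: the classes `ResumableList` and `ResumableConcat` and the `example_idx` key are hypothetical. Each child updates its counter before yielding (an assumption of this sketch), so a state saved right after an example is returned already counts that example as consumed.

```python
from copy import deepcopy


class ResumableList:
    """Toy stand-in for one iterable: yields list items and tracks its position."""

    def __init__(self, items):
        self.items = items
        self._state = {"example_idx": 0}

    def __iter__(self):
        for item in self.items[self._state["example_idx"]:]:
            # update before yielding, so the state counts this item as consumed
            self._state["example_idx"] += 1
            yield item


class ResumableConcat:
    """Toy concatenation whose state nests (shares) the states of its children."""

    def __init__(self, children):
        self.children = children
        self._state = {"ex_iterable_idx": 0, "ex_iterables": [c._state for c in children]}

    def state_dict(self):
        return deepcopy(self._state)

    def load_state_dict(self, state):
        self._state["ex_iterable_idx"] = state["ex_iterable_idx"]
        for child, child_state in zip(self.children, state["ex_iterables"]):
            child._state.update(child_state)  # in place, so the nesting stays shared

    def __iter__(self):
        while self._state["ex_iterable_idx"] < len(self.children):
            yield from self.children[self._state["ex_iterable_idx"]]
            self._state["ex_iterable_idx"] += 1


ds = ResumableConcat([ResumableList(list(range(5))), ResumableList(list(range(5)))])
it = iter(ds)
seen = [next(it) for _ in range(7)]  # 0, 1, 2, 3, 4, 0, 1
saved = ds.state_dict()
print(saved)  # {'ex_iterable_idx': 1, 'ex_iterables': [{'example_idx': 5}, {'example_idx': 2}]}

resumed = ResumableConcat([ResumableList(list(range(5))), ResumableList(list(range(5)))])
resumed.load_state_dict(saved)
print(list(resumed))  # [2, 3, 4]
```

The toy never resets its state at the end of an epoch, so iterating it a second time yields nothing.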
Would be nice to have this feature in the next `datasets` release!
Before finalising this, I'd like to make sure this philosophy makes sense for other libs, like accelerate for example.
cc @muellerzr, I'd love your feedback on this one. cc @LysandreJik as well, if you think other people should take a look.
One design question, though: what's the logic behind having `self._state_dict` rather than having it all be `state_dict`?
The `_state_dict` is the internal object that is updated in-place while you iterate on the dataset.
We need to copy it every time the user accesses it; otherwise we would get:
```python
state_dict = ds.state_dict()
for x in ds:
    assert ds.state_dict() == state_dict  # and actually `assert ds.state_dict() is state_dict`
```
The state is updated in-place since it's made of dictionaries that are shared with the steps in the IterableDataset pipeline.
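The difference is easy to see with a toy class (hypothetical, not the `datasets` code): the internal dict is mutated in place during iteration, so handing it out directly would give the caller a live view, while a `state_dict()` that returns a deep copy stays frozen at the moment it was taken.

```python
from copy import deepcopy


class Pipeline:
    """Toy iterable whose internal state dict is mutated in place while iterating."""

    def __init__(self, n):
        self.n = n
        self._state_dict = {"idx": 0}

    def __iter__(self):
        for i in range(self._state_dict["idx"], self.n):
            self._state_dict["idx"] += 1
            yield i

    def state_dict(self):
        """Return a copy of the current state; mutating it does not affect iteration."""
        return deepcopy(self._state_dict)

    def load_state_dict(self, state_dict):
        # update in place so that anything sharing this dict stays in sync
        self._state_dict.update(state_dict)


p = Pipeline(5)
snapshot = p.state_dict()
leaked = p._state_dict  # what a non-copying accessor would hand out
for _ in p:
    pass
print(snapshot)  # {'idx': 0}
print(leaked)    # {'idx': 5}
```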
What do you think of making it a full property with a docstring explicitly stating that users shouldn't call or modify it directly?
I can imagine some exploratory users getting curious.
I don't think users read docstrings of properties that often. What about explaining the logic in the `.state_dict()` docstring? This also feels aligned with the way `.state_dict()` and `.load_state_dict()` work in PyTorch (you should use `load_state_dict` to load a modified copy of the state dict).
Sure, I can agree with that!
Just a small note mentioning that it returns a copy of the state dict should be enough, imo.
Looking forward to this PR being merged as well!
Hi, I'm experimenting with LLM pretraining using your code. I found that it cuts the time needed to resume an iterable dataset to 5% of what it was (my streaming process includes tokenization), but I'm not sure if I'm using it correctly. Could you help me check it? Thanks.
```python
import json
import os

from transformers import Trainer


class CustomTrainer(Trainer):
    def _save_rng_state(self, output_dir):
        super()._save_rng_state(output_dir)
        # also save the IterableDataset state with each checkpoint
        if self.args.should_save:
            with open(os.path.join(output_dir, "iterable_data_state_dict.json"), "w", encoding="utf-8") as fo:
                json.dump(self.train_dataset.state_dict(), fo, ensure_ascii=False)


dataset = ...  # <an IterableDataset constructed by (interleave, map(tokenization))>

if training_args.resume_from_checkpoint:  # os.path.join would fail on None
    state_dict_path = os.path.join(training_args.resume_from_checkpoint, "iterable_data_state_dict.json")
    if os.path.exists(state_dict_path) and finetuning_args.load_iteratable_state_dict:
        if not training_args.ignore_data_skip:
            raise ValueError(
                f"Found `iterable_data_state_dict_file_path`: `{state_dict_path}`. "
                "Please set `ignore_data_skip=True` to skip tokenization."
            )
        with open(state_dict_path) as f:
            dataset.load_state_dict(json.load(f))
        logger.info(f"Loading `iterable_data_state_dict` from {state_dict_path}")
```
It sounds good to me :)
@lhoestq Hi, if I use prefetching, does resuming this dataset still work well?
It does work well if you prefetch and then resume from a state, but you might lose the samples that were in the prefetch buffer of the DataLoader (which could be acceptable in some circumstances).
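A small simulation shows where those samples go. The classes and the `consume_with_prefetch` helper are hypothetical toys, not the DataLoader or `datasets` API: the dataset state counts everything the loader has fetched, including examples still sitting in the prefetch buffer, so resuming from that state skips them.

```python
from collections import deque


class CountingSource:
    """Toy resumable source: its whole state is how many examples it has produced."""

    def __init__(self, n, start=0):
        self.n = n
        self.produced = start

    def __iter__(self):
        while self.produced < self.n:
            self.produced += 1
            yield self.produced - 1


def consume_with_prefetch(source, num_to_consume, prefetch=2):
    """Simulate a loader that keeps `prefetch` examples buffered ahead of the consumer."""
    buffer = deque()
    consumed = []
    for example in source:
        buffer.append(example)
        if len(buffer) > prefetch:
            consumed.append(buffer.popleft())
            if len(consumed) == num_to_consume:
                break
    return consumed, list(buffer)


source = CountingSource(10)
consumed, in_flight = consume_with_prefetch(source, num_to_consume=4)
saved_state = source.produced  # also counts the prefetched examples
resumed = list(CountingSource(10, start=saved_state))
print(consumed)   # [0, 1, 2, 3]
print(in_flight)  # [4, 5]  (prefetched, never consumed)
print(resumed)    # [6, 7, 8, 9]  (4 and 5 are skipped)
```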
Fortunately we're about to ship an integration with the new StatefulDataLoader from torchdata which can help on this matter :)
Yeah, what I meant is that prefetching might drop a few data entries. Really looking forward to the new StatefulDataLoader :)
PyArrow==8.0.0
Benchmark: benchmark_array_xd.json
| metric | read_batch_formatted_as_numpy after write_array2d | read_batch_formatted_as_numpy after write_flattened_sequence | read_batch_formatted_as_numpy after write_nested_sequence | read_batch_unformated after write_array2d | read_batch_unformated after write_flattened_sequence | read_batch_unformated after write_nested_sequence | read_col_formatted_as_numpy after write_array2d | read_col_formatted_as_numpy after write_flattened_sequence | read_col_formatted_as_numpy after write_nested_sequence | read_col_unformated after write_array2d | read_col_unformated after write_flattened_sequence | read_col_unformated after write_nested_sequence | read_formatted_as_numpy after write_array2d | read_formatted_as_numpy after write_flattened_sequence | read_formatted_as_numpy after write_nested_sequence | read_unformated after write_array2d | read_unformated after write_flattened_sequence | read_unformated after write_nested_sequence | write_array2d | write_flattened_sequence | write_nested_sequence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.005788 / 0.011353 (-0.005564) | 0.004036 / 0.011008 (-0.006972) | 0.064720 / 0.038508 (0.026212) | 0.034990 / 0.023109 (0.011881) | 0.245488 / 0.275898 (-0.030410) | 0.272596 / 0.323480 (-0.050884) | 0.003170 / 0.007986 (-0.004815) | 0.002867 / 0.004328 (-0.001461) | 0.049961 / 0.004250 (0.045711) | 0.050951 / 0.037052 (0.013899) | 0.257757 / 0.258489 (-0.000732) | 0.292957 / 0.293841 (-0.000884) | 0.027739 / 0.128546 (-0.100807) | 0.010942 / 0.075646 (-0.064705) | 0.205153 / 0.419271 (-0.214118) | 0.037892 / 0.043533 (-0.005641) | 0.247536 / 0.255139 (-0.007603) | 0.267239 / 0.283200 (-0.015960) | 0.021490 / 0.141683 (-0.120193) | 1.107306 / 1.452155 (-0.344848) | 1.144675 / 1.492716 (-0.348041) |
Benchmark: benchmark_getitem_100B.json
| metric | get_batch_of_1024_random_rows | get_batch_of_1024_rows | get_first_row | get_last_row |
|---|---|---|---|---|
| new / old (diff) | 0.103212 / 0.018006 (0.085205) | 0.315174 / 0.000490 (0.314684) | 0.000229 / 0.000200 (0.000029) | 0.000044 / 0.000054 (-0.000011) |
Benchmark: benchmark_indices_mapping.json
| metric | select | shard | shuffle | sort | train_test_split |
|---|---|---|---|---|---|
| new / old (diff) | 0.019771 / 0.037411 (-0.017641) | 0.064033 / 0.014526 (0.049507) | 0.076751 / 0.176557 (-0.099805) | 0.122615 / 0.737135 (-0.614521) | 0.078490 / 0.296338 (-0.217848) |
Benchmark: benchmark_iterating.json
| metric | read 5000 | read 50000 | read_batch 50000 10 | read_batch 50000 100 | read_batch 50000 1000 | read_formatted numpy 5000 | read_formatted pandas 5000 | read_formatted tensorflow 5000 | read_formatted torch 5000 | read_formatted_batch numpy 5000 10 | read_formatted_batch numpy 5000 1000 | shuffled read 5000 | shuffled read 50000 | shuffled read_batch 50000 10 | shuffled read_batch 50000 100 | shuffled read_batch 50000 1000 | shuffled read_formatted numpy 5000 | shuffled read_formatted_batch numpy 5000 10 | shuffled read_formatted_batch numpy 5000 1000 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.286236 / 0.215209 (0.071027) | 2.841469 / 2.077655 (0.763814) | 1.514079 / 1.504120 (0.009959) | 1.393792 / 1.541195 (-0.147403) | 1.432741 / 1.468490 (-0.035749) | 0.571003 / 4.584777 (-4.013774) | 2.369031 / 3.745712 (-1.376681) | 2.825246 / 5.269862 (-2.444616) | 1.858524 / 4.565676 (-2.707153) | 0.065366 / 0.424275 (-0.358909) | 0.005107 / 0.007607 (-0.002500) | 0.341010 / 0.226044 (0.114965) | 3.443894 / 2.268929 (1.174966) | 1.879192 / 55.444624 (-53.565433) | 1.603046 / 6.876477 (-5.273431) | 1.807639 / 2.142072 (-0.334433) | 0.646726 / 4.805227 (-4.158502) | 0.119409 / 6.500664 (-6.381255) | 0.044564 / 0.075469 (-0.030905) |
Benchmark: benchmark_map_filter.json
| metric | filter | map fast-tokenizer batched | map identity | map identity batched | map no-op batched | map no-op batched numpy | map no-op batched pandas | map no-op batched pytorch | map no-op batched tensorflow |
|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.971026 / 1.841788 (-0.870762) | 12.593884 / 8.074308 (4.519576) | 10.305243 / 10.191392 (0.113851) | 0.132018 / 0.680424 (-0.548406) | 0.014387 / 0.534201 (-0.519814) | 0.288597 / 0.579283 (-0.290686) | 0.267373 / 0.434364 (-0.166991) | 0.325626 / 0.540337 (-0.214711) | 0.488808 / 1.386936 (-0.898128) |
Benchmark: benchmark_array_xd.json
| metric | read_batch_formatted_as_numpy after write_array2d | read_batch_formatted_as_numpy after write_flattened_sequence | read_batch_formatted_as_numpy after write_nested_sequence | read_batch_unformated after write_array2d | read_batch_unformated after write_flattened_sequence | read_batch_unformated after write_nested_sequence | read_col_formatted_as_numpy after write_array2d | read_col_formatted_as_numpy after write_flattened_sequence | read_col_formatted_as_numpy after write_nested_sequence | read_col_unformated after write_array2d | read_col_unformated after write_flattened_sequence | read_col_unformated after write_nested_sequence | read_formatted_as_numpy after write_array2d | read_formatted_as_numpy after write_flattened_sequence | read_formatted_as_numpy after write_nested_sequence | read_unformated after write_array2d | read_unformated after write_flattened_sequence | read_unformated after write_nested_sequence | write_array2d | write_flattened_sequence | write_nested_sequence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.005991 / 0.011353 (-0.005362) | 0.004028 / 0.011008 (-0.006980) | 0.051951 / 0.038508 (0.013443) | 0.036870 / 0.023109 (0.013761) | 0.263777 / 0.275898 (-0.012122) | 0.290914 / 0.323480 (-0.032566) | 0.004594 / 0.007986 (-0.003392) | 0.002971 / 0.004328 (-0.001357) | 0.049699 / 0.004250 (0.045449) | 0.044939 / 0.037052 (0.007887) | 0.275055 / 0.258489 (0.016566) | 0.316244 / 0.293841 (0.022403) | 0.030501 / 0.128546 (-0.098045) | 0.011197 / 0.075646 (-0.064449) | 0.058718 / 0.419271 (-0.360554) | 0.034926 / 0.043533 (-0.008607) | 0.259172 / 0.255139 (0.004033) | 0.280127 / 0.283200 (-0.003072) | 0.019775 / 0.141683 (-0.121908) | 1.169468 / 1.452155 (-0.282687) | 1.178098 / 1.492716 (-0.314619) |
Benchmark: benchmark_getitem_100B.json
| metric | get_batch_of_1024_random_rows | get_batch_of_1024_rows | get_first_row | get_last_row |
|---|---|---|---|---|
| new / old (diff) | 0.101633 / 0.018006 (0.083626) | 0.314684 / 0.000490 (0.314194) | 0.000224 / 0.000200 (0.000024) | 0.000055 / 0.000054 (0.000001) |
Benchmark: benchmark_indices_mapping.json
| metric | select | shard | shuffle | sort | train_test_split |
|---|---|---|---|---|---|
| new / old (diff) | 0.024071 / 0.037411 (-0.013341) | 0.079894 / 0.014526 (0.065368) | 0.090915 / 0.176557 (-0.085642) | 0.132397 / 0.737135 (-0.604738) | 0.091919 / 0.296338 (-0.204419) |
Benchmark: benchmark_iterating.json
| metric | read 5000 | read 50000 | read_batch 50000 10 | read_batch 50000 100 | read_batch 50000 1000 | read_formatted numpy 5000 | read_formatted pandas 5000 | read_formatted tensorflow 5000 | read_formatted torch 5000 | read_formatted_batch numpy 5000 10 | read_formatted_batch numpy 5000 1000 | shuffled read 5000 | shuffled read 50000 | shuffled read_batch 50000 10 | shuffled read_batch 50000 100 | shuffled read_batch 50000 1000 | shuffled read_formatted numpy 5000 | shuffled read_formatted_batch numpy 5000 10 | shuffled read_formatted_batch numpy 5000 1000 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.296237 / 0.215209 (0.081028) | 2.891752 / 2.077655 (0.814097) | 1.551937 / 1.504120 (0.047817) | 1.414179 / 1.541195 (-0.127016) | 1.450192 / 1.468490 (-0.018298) | 0.556272 / 4.584777 (-4.028504) | 0.952374 / 3.745712 (-2.793339) | 2.709450 / 5.269862 (-2.560411) | 1.771251 / 4.565676 (-2.794426) | 0.061873 / 0.424275 (-0.362402) | 0.005058 / 0.007607 (-0.002549) | 0.344790 / 0.226044 (0.118746) | 3.398982 / 2.268929 (1.130053) | 1.905832 / 55.444624 (-53.538792) | 1.632357 / 6.876477 (-5.244120) | 1.822913 / 2.142072 (-0.319160) | 0.643426 / 4.805227 (-4.161802) | 0.117321 / 6.500664 (-6.383343) | 0.042107 / 0.075469 (-0.033363) |
Benchmark: benchmark_map_filter.json
| metric | filter | map fast-tokenizer batched | map identity | map identity batched | map no-op batched | map no-op batched numpy | map no-op batched pandas | map no-op batched pytorch | map no-op batched tensorflow |
|---|---|---|---|---|---|---|---|---|---|
| new / old (diff) | 0.974921 / 1.841788 (-0.866867) | 12.497801 / 8.074308 (4.423493) | 11.216174 / 10.191392 (1.024782) | 0.135288 / 0.680424 (-0.545136) | 0.016731 / 0.534201 (-0.517470) | 0.287987 / 0.579283 (-0.291296) | 0.130246 / 0.434364 (-0.304117) | 0.323282 / 0.540337 (-0.217055) | 0.414595 / 1.386936 (-0.972341) |

@lhoestq Hello, I'm wondering whether there is any way to make this work with `shuffle` now. I've noticed this caveat in the docs:
> examples from shuffle buffers are lost when resuming and the buffers are refilled with new data
Hi! I haven't experimented with implementing `state_dict` for the shuffle buffer. I'm not sure it's a good idea to add this, given that a shuffle buffer can be quite big and poses serialization challenges.
It shouldn't be difficult to experiment with a simple implementation in `BufferShuffledExamplesIterable`, though.
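To experiment along those lines without touching the library, here is a toy, pure-Python sketch of a resumable shuffle buffer. It is not `BufferShuffledExamplesIterable`: the class name, its state keys and the use of Python's `random` module instead of a NumPy generator are assumptions. The state stores the source position, a copy of the buffer and the RNG state, and the picked example is swapped out before it is yielded, so a state saved at any yield point resumes to exactly the same remaining sequence.

```python
import random


class ResumableShuffleBuffer:
    """Toy shuffle buffer over a list; its state includes the buffer and the RNG state."""

    def __init__(self, source, buffer_size, seed=0):
        self.source = source
        self.buffer_size = buffer_size
        self.rng = random.Random(seed)
        self._state = {"source_idx": 0, "buffer": []}

    def state_dict(self):
        return {
            "source_idx": self._state["source_idx"],
            "buffer": list(self._state["buffer"]),  # shallow copy of the buffer
            "rng_state": self.rng.getstate(),
        }

    def load_state_dict(self, state):
        self._state = {"source_idx": state["source_idx"], "buffer": list(state["buffer"])}
        self.rng.setstate(state["rng_state"])

    def __iter__(self):
        buffer = self._state["buffer"]
        while self._state["source_idx"] < len(self.source):
            x = self.source[self._state["source_idx"]]
            self._state["source_idx"] += 1
            if len(buffer) < self.buffer_size:
                buffer.append(x)  # still filling the buffer
                continue
            i = self.rng.randrange(self.buffer_size)
            example, buffer[i] = buffer[i], x  # swap in the new example before yielding
            yield example
        # source exhausted: drain in random order, removing each item as it is yielded
        while buffer:
            yield buffer.pop(self.rng.randrange(len(buffer)))


source = list(range(20))
full = list(ResumableShuffleBuffer(source, buffer_size=5, seed=42))

a = ResumableShuffleBuffer(source, buffer_size=5, seed=42)
it = iter(a)
first = [next(it) for _ in range(8)]
b = ResumableShuffleBuffer(source, buffer_size=5, seed=123)  # seed is overwritten by the state
b.load_state_dict(a.state_dict())
print(first + list(b) == full)  # True
```

The buffer is part of the state, so the state grows with `buffer_size`, which is the serialization concern raised above. Also, `random.Random.getstate()` returns nested tuples, so a JSON round trip would need converting the lists back to tuples before calling `setstate()`.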
@lhoestq Thank you for your quick response! I'll try it :}
@lhoestq Hi, I just revised `BufferShuffledExamplesIterable` and it works:
```python
from __future__ import annotations  # allows the class to reference itself in annotations

from copy import deepcopy

import datasets
import numpy as np


class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _init_state_dict(self) -> dict:
        self._state_dict = self.ex_iterable._init_state_dict()
        # load_state_dict below only recurses into dicts and lists, so a tuple is replaced as a whole
        self._state_dict["mem_buffer"] = ([],)
        self._state_dict["global_example_index"] = 0
        return self._state_dict

    def __iter__(self):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # this is the shuffle buffer that we keep in memory (restored from the state dict, if any)
        mem_buffer = self._state_dict["mem_buffer"][0] if self._state_dict else []
        global_example_index_start = self._state_dict["global_example_index"] if self._state_dict else 0
        # replay the random indices that were already consumed before the checkpoint
        for _ in range(global_example_index_start):
            next(indices_iterator)
        for x in self.ex_iterable:
            if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
                i = next(indices_iterator)
                if self._state_dict:
                    self._state_dict["global_example_index"] += 1
                # replace the picked example by the new one *before* yielding, so that a state
                # saved at this point neither loses `x` nor re-yields the picked example
                example, mem_buffer[i] = mem_buffer[i], x
                yield example
            else:  # otherwise, keep filling the buffer
                mem_buffer.append(x)
        # when we run out of examples, we shuffle the remaining examples in the buffer and yield them,
        # removing each one so that a state saved during this phase no longer contains it
        rng.shuffle(mem_buffer)
        while mem_buffer:
            yield mem_buffer.pop()

    def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
        )

    def shard_data_sources(self, worker_id: int, num_workers: int) -> BufferShuffledExamplesIterable:
        """Keep only the requested shard."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shard_data_sources(worker_id, num_workers),
            buffer_size=self.buffer_size,
            generator=self.generator,
        )

    def load_state_dict(self, state_dict: dict) -> dict:
        def _inner_load_state_dict(state, new_state):
            if new_state is not None and isinstance(state, dict):
                for key in state:
                    state[key] = _inner_load_state_dict(state[key], new_state[key])
                return state
            elif new_state is not None and isinstance(state, list):
                for i in range(len(state)):
                    state[i] = _inner_load_state_dict(state[i], new_state[i])
                return state
            return new_state

        return _inner_load_state_dict(self._state_dict, state_dict)
```
I've noticed that it uses significantly more RAM than the original version, and GPU utilization drops considerably. Could you offer some suggestions to address this issue?
Or is it simply not feasible to keep anything other than small indices in the state of each worker? 😢
Some `ExamplesIterable`s copy and store older versions of their parent `ExamplesIterable`'s `state_dict`. This is the case, for example, for batched `map()` (which keeps the `state_dict` from the beginning of the batch) or `interleave_datasets()` (which keeps the `state_dict` of the previous step, since it buffers one example to know whether the iterable is exhausted).
Copying a shuffle buffer takes some RAM and some time, which can slow down the data loading pipeline. Maybe the examples in the shuffle buffer shouldn't be copied (only do a shallow copy of the list); this would surely help.
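One way to get that shallow copy while still calling `deepcopy` on the whole state dict is a list subclass that overrides `__deepcopy__`. This is a hypothetical helper, not something `datasets` provides; it relies only on the standard `copy` protocol, and the state keys below are illustrative.

```python
from copy import deepcopy


class ShallowCopyList(list):
    """List whose deepcopy copies the container but shares the elements (hypothetical helper)."""

    def __deepcopy__(self, memo):
        return ShallowCopyList(self)


state = {
    "shard_idx": 0,
    "shard_example_idx": 3,
    "mem_buffer": ShallowCopyList([{"input_ids": list(range(1000))} for _ in range(4)]),
}
snapshot = deepcopy(state)
print(snapshot["mem_buffer"] is state["mem_buffer"])        # False: the list itself is copied
print(snapshot["mem_buffer"][0] is state["mem_buffer"][0])  # True: the examples are shared
```

Sharing the examples is only safe if the buffer replaces examples rather than mutating them in place, which is what a shuffle buffer does when it assigns a new example to a slot.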