
[Resumable IterableDataset] Add IterableDataset state_dict

Open lhoestq opened this issue 1 year ago • 14 comments

A simple implementation of a mechanism to resume an IterableDataset. This is WIP and untested.

Example:

from datasets import Dataset, concatenate_datasets


ds = Dataset.from_dict({"a": range(5)}).to_iterable_dataset(num_shards=3)
ds = concatenate_datasets([ds] * 2)

print(f"{ds.state_dict()=}")
for i, example in enumerate(ds):
    print(example)
    if i == 6:
        state_dict = ds.state_dict()  # snapshot the position mid-iteration
ds.load_state_dict(state_dict)  # restore the snapshot: iteration resumes from there
print(f"{ds.state_dict()=}")
for example in ds:
    print(example)

returns

ds.state_dict()={'ex_iterable_idx': 0, 'ex_iterables': [{'shard_idx': 0, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 0}]}
{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
ds.state_dict()={'ex_iterable_idx': 1, 'ex_iterables': [{'shard_idx': 3, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 2}]}
{'a': 2}
{'a': 3}
{'a': 4}
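
Since the state is made of plain dicts and integers, it can be persisted between runs, e.g. as JSON. A minimal sketch (a later comment in this thread does the same inside a Trainer):

import json

# save the position reached so far
with open("dataset_state.json", "w") as f:
    json.dump(ds.state_dict(), f)

# in a new process: rebuild `ds` the same way, then restore the position
with open("dataset_state.json") as f:
    ds.load_state_dict(json.load(f))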

lhoestq avatar Feb 11 '24 20:02 lhoestq


Would be nice to have this feature in the next datasets release!

bwanglzu avatar Apr 11 '24 10:04 bwanglzu

Before finalising this I'd like to make sure this philosophy makes sense for other libs, like accelerate for example.

cc @muellerzr, I'd love your feedback on this one. cc @LysandreJik also, in case you think other people should take a look.

lhoestq avatar Apr 11 '24 13:04 lhoestq

One design question though: what's the logic behind self._state_dict rather than having it all be state_dict?

The _state_dict is the internal object that is updated in-place while you iterate on the dataset.

We need to copy it every time the user accesses it.

Otherwise we would get

state_dict = ds.state_dict()
for x in ds:
    assert ds.state_dict() == state_dict  # and actually `assert ds.state_dict() is state_dict`

The state is updated in-place since it's made of dictionaries that are shared with the steps in the IterableDataset pipeline.
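
For illustration, a minimal sketch of this copy-on-access pattern (class and field names are hypothetical, not the actual datasets internals):

import copy

class ToyExamplesIterable:
    def __init__(self):
        # one dict shared with (and mutated by) the steps of the pipeline
        self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}

    def __iter__(self):
        for i in range(5):
            self._state_dict["shard_example_idx"] = i + 1  # in-place update
            yield {"a": i}

    def state_dict(self):
        # return a deep copy so iterating afterwards doesn't mutate the user's snapshot
        return copy.deepcopy(self._state_dict)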

lhoestq avatar Apr 15 '24 10:04 lhoestq

What do you think of making it a full property with a docstring explicitly stating users shouldn’t call/modify it directly?

I can imagine some exploratory users getting curious

muellerzr avatar Apr 15 '24 10:04 muellerzr

I don't think users read docstrings of properties that often. What about explaining the logic in the .state_dict() docstring? This also feels aligned with the way .state_dict() and .load_state_dict() work in PyTorch (you should use load_state_dict to load a modified copy of the state dict)
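
For reference, the PyTorch pattern being mirrored (a minimal sketch):

import torch

model = torch.nn.Linear(2, 2)
state = model.state_dict()                      # take a snapshot of the parameters
torch.save(state, "ckpt.pt")                    # persist it
model.load_state_dict(torch.load("ckpt.pt"))    # restore explicitly via load_state_dict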

lhoestq avatar Apr 15 '24 10:04 lhoestq

Sure, I can agree with that!

muellerzr avatar Apr 15 '24 12:04 muellerzr

Just a small note mentioning that it returns a copy of the state dict should be enough imo

muellerzr avatar Apr 15 '24 12:04 muellerzr

Looking forward as well to this PR being merged

samsja avatar May 07 '24 10:05 samsja


Hi, I'm experimenting with LLM pretraining using your code. I found that resuming an iterable dataset now takes only about 5% of the time it used to (my streaming pipeline includes tokenization), but I'm not sure I'm using it correctly. Could you help me check? Thanks.

import json
import os

from transformers import Trainer


class CustomTrainer(Trainer):
    def _save_rng_state(self, output_dir):
        # save the dataset state alongside the RNG state at every checkpoint
        super()._save_rng_state(output_dir)
        if self.args.should_save:
            with open(os.path.join(output_dir, 'iterable_data_state_dict.json'), 'w', encoding='utf-8') as fo:
                json.dump(self.train_dataset.state_dict(), fo, ensure_ascii=False)


# At resume time, restore the dataset state from the last checkpoint:
dataset = ...  # an IterableDataset constructed by (interleave, map(tokenization))
last_ckpt_iterable_data_state_dict_file_path = os.path.join(training_args.resume_from_checkpoint, 'iterable_data_state_dict.json')
if os.path.exists(last_ckpt_iterable_data_state_dict_file_path) and finetuning_args.load_iterable_state_dict:
    if not training_args.ignore_data_skip:
        raise ValueError(
            f'Found `iterable_data_state_dict_file_path`: `{last_ckpt_iterable_data_state_dict_file_path}`. '
            f'Please set `ignore_data_skip=True` to skip tokenization.'
        )
    with open(last_ckpt_iterable_data_state_dict_file_path) as f:
        last_ckpt_iterable_data_state_dict = json.load(f)
        dataset.load_state_dict(last_ckpt_iterable_data_state_dict)
        logger.info(f'Loading `iterable_data_state_dict` from {last_ckpt_iterable_data_state_dict_file_path}')

fyubang avatar May 20 '24 08:05 fyubang

it sounds good to me :)

lhoestq avatar May 21 '24 09:05 lhoestq

@lhoestq Hi, if I enable prefetching, does this dataset still work well?

uygnef avatar May 24 '24 09:05 uygnef

It does work well if you prefetch and then resume from a state, but you may lose the samples that were in the DataLoader's prefetch buffer (which can be acceptable in some circumstances).

Fortunately, we're about to ship an integration with the new StatefulDataLoader from torchdata, which can help with this :)
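
A minimal sketch of what that could look like, assuming the StatefulDataLoader API from torchdata (the integration wasn't shipped yet at this point):

from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)
for i, batch in enumerate(dataloader):
    ...
    if i == 100:
        # unlike ds.state_dict(), this also tracks in-flight prefetched samples
        state = dataloader.state_dict()
        break

# later: recreate the dataloader and resume without losing prefetched samples
dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)
dataloader.load_state_dict(state)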

lhoestq avatar May 24 '24 09:05 lhoestq

Yeah, what I meant is that prefetch might drop a few data entries. Really looking forward to the new StatefulDataLoader :)

uygnef avatar May 24 '24 09:05 uygnef


@lhoestq Hello, I'm wondering if there is any solution for working with shuffle now. I've noticed the caveat in the docs:

examples from shuffle buffers are lost when resuming and the buffers are refilled with new data

yzhangcs avatar Jun 20 '24 11:06 yzhangcs

Hi! I haven't experimented with implementing state_dict for the shuffle buffer. I'm not sure it's a good idea to add it, given that a shuffle buffer can be quite big and poses serialization challenges.

It shouldn't be difficult to experiment with a simple implementation in BufferShuffledExamplesIterable though.

lhoestq avatar Jun 20 '24 13:06 lhoestq

@lhoestq thank you for your quick response! I'll try it :}

yzhangcs avatar Jun 20 '24 13:06 yzhangcs

@lhoestq Hi, I just revised BufferShuffledExamplesIterable and it works:


from copy import deepcopy

import datasets
import numpy as np


class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):

    def _init_state_dict(self) -> dict:
        self._state_dict = self.ex_iterable._init_state_dict()
        # the buffer list is stored wrapped in a tuple and mutated in place during iteration
        self._state_dict['mem_buffer'] = ([],)
        self._state_dict['global_example_index'] = 0
        return self._state_dict

    def __iter__(self):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # this is the shuffle buffer that we keep in memory
        mem_buffer = self._state_dict['mem_buffer'][0]
        global_example_index_start = self._state_dict["global_example_index"] if self._state_dict else 0
        # skip the random indices already consumed before the checkpoint,
        # so the RNG stream stays aligned with the resumed position
        for i in range(global_example_index_start):
            _ = next(indices_iterator)
        for x in self.ex_iterable:
            if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
                i = next(indices_iterator)
                if self._state_dict:
                    self._state_dict['global_example_index'] += 1
                yield mem_buffer[i]
                mem_buffer[i] = x  # replace the picked example by a new one
            else:  # otherwise, keep filling the buffer
                mem_buffer.append(x)
        # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
        rng.shuffle(mem_buffer)
        yield from mem_buffer

    def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable":
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
        )

    def shard_data_sources(self, worker_id: int, num_workers: int) -> "BufferShuffledExamplesIterable":
        """Keep only the requested shard."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shard_data_sources(worker_id, num_workers),
            buffer_size=self.buffer_size,
            generator=self.generator,
        )

    def load_state_dict(self, state_dict: dict) -> dict:
        # recursively copy the new state into the existing (shared) state objects in place
        def _inner_load_state_dict(state, new_state):
            if new_state is not None and isinstance(state, dict):
                for key in state:
                    state[key] = _inner_load_state_dict(state[key], new_state[key])
                return state
            elif new_state is not None and isinstance(state, list):
                for i in range(len(state)):
                    state[i] = _inner_load_state_dict(state[i], new_state[i])
                return state
            return new_state

        return _inner_load_state_dict(self._state_dict, state_dict)

I've noticed that it uses significantly more RAM than the original version and that GPU utilization drops considerably. Could you offer some suggestions to address this?

Or is it prohibitive to maintain anything in the state other than simple indices small enough for each worker 😢

yzhangcs avatar Jun 23 '24 07:06 yzhangcs

Some ExamplesIterable objects copy and store old versions of the state_dict of their parent ExamplesIterable. This is the case, for example, for batched map() (the state_dict from the beginning of the batch) or interleave_datasets() (the state_dict of the previous step, since it buffers one example to know whether the iterable is exhausted).

Copying a shuffle buffer takes some RAM and some time, which can slow down the data loading pipeline. Maybe the examples in the shuffle buffer shouldn't be copied (only do a shallow copy of the list); this would surely help.
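
A minimal sketch of the shallow-copy idea (hypothetical names, not the actual datasets internals):

import copy

state = {"global_example_index": 2, "mem_buffer": [{"a": 0}, {"a": 1}]}

# deep copy duplicates every buffered example (costly for a large buffer)
snapshot_deep = copy.deepcopy(state)

# shallow-copying only the list keeps references to the examples themselves,
# so taking a snapshot stays cheap regardless of buffer size
snapshot_cheap = dict(state)
snapshot_cheap["mem_buffer"] = list(state["mem_buffer"])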

lhoestq avatar Jun 24 '24 09:06 lhoestq