interleave_datasets seed with multiple workers
### Describe the bug
Using `interleave_datasets` with multiple dataloader workers and a seed set causes the same dataset sampling order across all workers.
Should the seed be modulated with the worker id?
### Steps to reproduce the bug
See above
### Expected behavior
See above
### Environment info
- `datasets` version: 3.5.1
- Platform: macOS-15.4.1-arm64-arm-64bit
- Python version: 3.12.9
- `huggingface_hub` version: 0.30.2
- PyArrow version: 19.0.1
- Pandas version: 2.2.3
- `fsspec` version: 2024.12.0
Hi! It's already the case IIRC: the effective seed looks like `seed + worker_id`. Do you have a reproducible example?
Here is an example with `shuffle`:
```python
import itertools
import datasets
import multiprocessing
import torch.utils.data


def gen(shard):
    worker_info = torch.utils.data.get_worker_info()
    for i in range(10):
        yield {'value': i, 'worker_id': worker_info.id}


def main():
    ds = datasets.IterableDataset.from_generator(gen, gen_kwargs={'shard': list(range(8))})
    ds = ds.shuffle(buffer_size=100, seed=1234)
    dataloader = torch.utils.data.DataLoader(ds, batch_size=None, num_workers=8)
    for i, ex in enumerate(itertools.islice(dataloader, 50)):
        print(i, ex)


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')
    main()
```
```
$ python test.py
0 {'value': 8, 'worker_id': 0}
1 {'value': 8, 'worker_id': 1}
2 {'value': 8, 'worker_id': 2}
3 {'value': 8, 'worker_id': 3}
4 {'value': 8, 'worker_id': 4}
5 {'value': 8, 'worker_id': 5}
6 {'value': 8, 'worker_id': 6}
7 {'value': 8, 'worker_id': 7}
8 {'value': 9, 'worker_id': 0}
9 {'value': 9, 'worker_id': 1}
10 {'value': 9, 'worker_id': 2}
11 {'value': 9, 'worker_id': 3}
12 {'value': 9, 'worker_id': 4}
13 {'value': 9, 'worker_id': 5}
14 {'value': 9, 'worker_id': 6}
15 {'value': 9, 'worker_id': 7}
16 {'value': 5, 'worker_id': 0}
17 {'value': 5, 'worker_id': 1}
18 {'value': 5, 'worker_id': 2}
19 {'value': 5, 'worker_id': 3}
```
With `interleave_datasets`:
```python
import itertools
import datasets
import multiprocessing
import torch.utils.data


def gen(shard, value):
    while True:
        yield {'value': value}


def main():
    ds = [
        datasets.IterableDataset.from_generator(gen, gen_kwargs={'shard': list(range(8)), 'value': i})
        for i in range(10)
    ]
    ds = datasets.interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
    dataloader = torch.utils.data.DataLoader(ds, batch_size=None, num_workers=8)
    for i, ex in enumerate(itertools.islice(dataloader, 50)):
        print(i, ex)


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')
    main()
```
```
$ python test.py
0 {'value': 9}
1 {'value': 9}
2 {'value': 9}
3 {'value': 9}
4 {'value': 9}
5 {'value': 9}
6 {'value': 9}
7 {'value': 9}
8 {'value': 3}
9 {'value': 3}
10 {'value': 3}
11 {'value': 3}
12 {'value': 3}
13 {'value': 3}
14 {'value': 3}
15 {'value': 3}
16 {'value': 9}
17 {'value': 9}
18 {'value': 9}
19 {'value': 9}
20 {'value': 9}
21 {'value': 9}
22 {'value': 9}
23 {'value': 9}
```
Same results after updating to `datasets` 3.6.0.
Ah my bad, `shuffle()` uses a global effective seed which is something like `seed + epoch`; it is used to do the same shard shuffle in each worker so that each worker has a non-overlapping set of shards:
https://github.com/huggingface/datasets/blob/b9efdc64c3bfb8f21f8a4a22b21bddd31ecd5a31/src/datasets/iterable_dataset.py#L2102-L2111
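For illustration, here is roughly why that shard shuffle needs the same effective seed in every worker (a simplified sketch, not the actual `datasets` internals):

```python
import numpy as np

seed, epoch, num_workers = 1234, 0, 4
shards = list(range(8))

# Every worker builds the SAME permutation from seed + epoch...
rng = np.random.default_rng(seed + epoch)
shuffled = [shards[i] for i in rng.permutation(len(shards))]

# ...so the per-worker slices are guaranteed to be disjoint.
for worker_id in range(num_workers):
    print(worker_id, shuffled[worker_id::num_workers])
```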
I think we should take the `worker_id` into account in a local seed for the buffer right after this line:
https://github.com/huggingface/datasets/blob/b9efdc64c3bfb8f21f8a4a22b21bddd31ecd5a31/src/datasets/iterable_dataset.py#L2151-L2153
like adding a new step that would propagate through the examples iterables, or something like that:

```python
ex_iterable = ex_iterable.shift_rngs(value=worker_id)
```
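For illustration, a rough sketch of what such a step could do, assuming the examples iterables expose a `generator` and child `ex_iterables` attributes (hypothetical, not an existing API):

```python
import copy

import numpy as np


def shift_rngs(ex_iterable, value: int):
    """Hypothetical sketch: recursively re-seed every rng held by the
    examples iterables so each worker draws a different sampling sequence."""
    if hasattr(ex_iterable, "generator"):
        # Fold the worker id into a fresh per-worker generator.
        new_seed = int(copy.deepcopy(ex_iterable.generator).integers(0, 1 << 63)) + value
        ex_iterable.generator = np.random.default_rng(new_seed)
    for child in getattr(ex_iterable, "ex_iterables", []):
        shift_rngs(child, value)
    return ex_iterable
```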
Is this something you'd like to explore? Contributions on this subject are very welcome.
Potentially, but I'm busy. If anyone wants to take this up please feel free to; otherwise I may or may not revisit when I have free time.
For what it's worth, I got around this with:
```python
import copy

import numpy as np
from datasets import iterable_dataset


class SeedGeneratorWithWorkerIterable(iterable_dataset._BaseExamplesIterable):
    """ExamplesIterable that seeds the rng with the worker id."""

    def __init__(
        self,
        ex_iterable: iterable_dataset._BaseExamplesIterable,
        generator: np.random.Generator,
        rank: int = 0,
    ):
        """Constructor."""
        super().__init__()
        self.ex_iterable = ex_iterable
        self.generator = generator
        self.rank = rank

    def _init_state_dict(self) -> dict:
        self._state_dict = self.ex_iterable._init_state_dict()
        return self._state_dict

    def __iter__(self):
        """Data iterator."""
        # Derive a per-worker seed: draw from a copy of the shared generator,
        # then offset by this worker's rank (wrapping to stay non-negative).
        effective_seed = copy.deepcopy(self.generator).integers(0, 1 << 63) - self.rank
        effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
        generator = np.random.default_rng(effective_seed)
        self.ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
        if self._state_dict:
            self._state_dict = self.ex_iterable._init_state_dict()
        yield from iter(self.ex_iterable)

    def shuffle_data_sources(self, generator):
        """Shuffle data sources."""
        ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
        return SeedGeneratorWithWorkerIterable(ex_iterable, generator=generator, rank=self.rank)

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True):  # noqa: FBT002
        """Shard data sources; the shard index becomes this worker's rank."""
        ex_iterable = self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
        return SeedGeneratorWithWorkerIterable(ex_iterable, generator=self.generator, rank=index)

    @property
    def is_typed(self):
        return self.ex_iterable.is_typed

    @property
    def features(self):
        return self.ex_iterable.features

    @property
    def num_shards(self) -> int:
        """Number of shards."""
        return self.ex_iterable.num_shards
```
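For context, a hypothetical usage sketch of the wrapper above; it reaches into the private `_ex_iterable` attribute, so it relies on `datasets` internals:

```python
import numpy as np
import datasets


def gen(value):
    while True:
        yield {'value': value}


list_of_datasets = [
    datasets.IterableDataset.from_generator(gen, gen_kwargs={'value': i})
    for i in range(10)
]
ds = datasets.interleave_datasets(list_of_datasets, probabilities=[0.1] * 10, seed=1234)
# Wrap the interleaved dataset's examples iterable. When the DataLoader shards
# the dataset across workers, shard_data_sources() is called with the worker
# index, which becomes the rank used to shift the sampling seed.
ds._ex_iterable = SeedGeneratorWithWorkerIterable(
    ds._ex_iterable,
    generator=np.random.default_rng(1234),
)
```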
Thanks for the detailed insights!
After reviewing the issue and the current implementation in `iterable_dataset.py`, I can confirm the cause:
When using `interleave_datasets(..., seed=...)` with `num_workers > 1` (e.g. via `DataLoader`), the same RNG state is shared across workers, which leads to each worker producing identical sample sequences. This is because the seed is not modulated by the `worker_id`, unlike the approach in `shuffle()`, where the seed is adjusted using the epoch.
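For illustration, the difference between the two seeding schemes might look like this (simplified, hypothetical values rather than the actual internals):

```python
import numpy as np

seed, epoch = 1234, 0
for worker_id in range(4):
    # Current behavior: every worker ends up with an identically seeded
    # generator, so probabilistic interleaving samples the same sequence.
    shared_rng = np.random.default_rng(seed + epoch)
    # Proposed fix: fold the worker id in so each worker samples differently.
    per_worker_rng = np.random.default_rng(seed + epoch + worker_id)
    print(worker_id, shared_rng.integers(0, 10, 5), per_worker_rng.integers(0, 10, 5))
```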
As @lhoestq suggested, a proper fix would involve introducing something like:

```python
ex_iterable = ex_iterable.shift_rngs(worker_id)
```
@jonathanasdf Also, I really appreciate the workaround implementation shared above; it was helpful for validating the behavior and will help shape the general solution.