interleave_datasets seed with multiple workers

Open jonathanasdf opened this issue 8 months ago • 6 comments

Describe the bug

Using interleave_datasets with multiple dataloader workers and a seed set causes the same dataset sampling order across all workers.

Should the seed be modulated with the worker id?

Steps to reproduce the bug

See above

Expected behavior

See above

Environment info

  • datasets version: 3.5.1
  • Platform: macOS-15.4.1-arm64-arm-64bit
  • Python version: 3.12.9
  • huggingface_hub version: 0.30.2
  • PyArrow version: 19.0.1
  • Pandas version: 2.2.3
  • fsspec version: 2024.12.0

jonathanasdf avatar May 12 '25 22:05 jonathanasdf

Hi! It's already the case, IIRC: the effective seed looks like seed + worker_id. Do you have a reproducible example?

lhoestq avatar May 14 '25 22:05 lhoestq

Here is an example with shuffle:

import itertools
import datasets
import multiprocessing
import torch.utils.data


def gen(shard):
  worker_info = torch.utils.data.get_worker_info()
  for i in range(10):
    yield {'value': i, 'worker_id': worker_info.id}


def main():
  ds = datasets.IterableDataset.from_generator(gen, gen_kwargs={'shard': list(range(8))})
  ds = ds.shuffle(buffer_size=100, seed=1234)
  dataloader = torch.utils.data.DataLoader(ds, batch_size=None, num_workers=8)
  for i, ex in enumerate(itertools.islice(dataloader, 50)):
    print(i, ex)


if __name__ == '__main__':
  multiprocessing.set_start_method('spawn')
  main()

$ python test.py
0 {'value': 8, 'worker_id': 0}
1 {'value': 8, 'worker_id': 1}
2 {'value': 8, 'worker_id': 2}
3 {'value': 8, 'worker_id': 3}
4 {'value': 8, 'worker_id': 4}
5 {'value': 8, 'worker_id': 5}
6 {'value': 8, 'worker_id': 6}
7 {'value': 8, 'worker_id': 7}
8 {'value': 9, 'worker_id': 0}
9 {'value': 9, 'worker_id': 1}
10 {'value': 9, 'worker_id': 2}
11 {'value': 9, 'worker_id': 3}
12 {'value': 9, 'worker_id': 4}
13 {'value': 9, 'worker_id': 5}
14 {'value': 9, 'worker_id': 6}
15 {'value': 9, 'worker_id': 7}
16 {'value': 5, 'worker_id': 0}
17 {'value': 5, 'worker_id': 1}
18 {'value': 5, 'worker_id': 2}
19 {'value': 5, 'worker_id': 3}

jonathanasdf avatar May 15 '25 00:05 jonathanasdf

With interleave_datasets:

import itertools
import datasets
import multiprocessing
import torch.utils.data


def gen(shard, value):
  while True:
    yield {'value': value}


def main():
  ds = [
    datasets.IterableDataset.from_generator(gen, gen_kwargs={'shard': list(range(8)), 'value': i})
    for i in range(10)
  ]
  ds = datasets.interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
  dataloader = torch.utils.data.DataLoader(ds, batch_size=None, num_workers=8)
  for i, ex in enumerate(itertools.islice(dataloader, 50)):
    print(i, ex)


if __name__ == '__main__':
  multiprocessing.set_start_method('spawn')
  main()

$ python test.py
0 {'value': 9}
1 {'value': 9}
2 {'value': 9}
3 {'value': 9}
4 {'value': 9}
5 {'value': 9}
6 {'value': 9}
7 {'value': 9}
8 {'value': 3}
9 {'value': 3}
10 {'value': 3}
11 {'value': 3}
12 {'value': 3}
13 {'value': 3}
14 {'value': 3}
15 {'value': 3}
16 {'value': 9}
17 {'value': 9}
18 {'value': 9}
19 {'value': 9}
20 {'value': 9}
21 {'value': 9}
22 {'value': 9}
23 {'value': 9}

jonathanasdf avatar May 15 '25 01:05 jonathanasdf

Same results after updating to datasets 3.6.0.

jonathanasdf avatar May 15 '25 01:05 jonathanasdf

Ah, my bad: shuffle() uses a global effective seed, something like seed + epoch. It is used to apply the same shard shuffle in each worker, so that each worker gets a non-overlapping set of shards:

https://github.com/huggingface/datasets/blob/b9efdc64c3bfb8f21f8a4a22b21bddd31ecd5a31/src/datasets/iterable_dataset.py#L2102-L2111
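
To illustrate why a shared seed is useful at the shard level, here is a minimal sketch. It is not the library's code: the function name `shards_for_worker` and the strided split are assumptions made for the example. Every worker shuffles the full shard list with the same seed + epoch and then keeps only its own slice, so the slices cannot overlap.

```python
import numpy as np


def shards_for_worker(shards, seed, epoch, worker_id, num_workers):
    """Hypothetical helper: shards assigned to one worker after a shared shuffle."""
    # Same seed + epoch in every worker, so every worker sees the same order.
    rng = np.random.default_rng(seed + epoch)
    shuffled = list(shards)
    rng.shuffle(shuffled)
    # Strided split (an assumption for this sketch): worker k takes every
    # num_workers-th shard starting at position k.
    return shuffled[worker_id::num_workers]


num_workers = 4
assignments = [
    shards_for_worker(range(8), seed=1234, epoch=0, worker_id=w, num_workers=num_workers)
    for w in range(num_workers)
]
for worker_id, shards in enumerate(assignments):
    print(worker_id, shards)
```

If each worker used a different seed for this step, the shuffled orders would differ and two workers could end up reading the same shard.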

I think we should take into account the worker_id in a local seed for the buffer right after this line:

https://github.com/huggingface/datasets/blob/b9efdc64c3bfb8f21f8a4a22b21bddd31ecd5a31/src/datasets/iterable_dataset.py#L2151-L2153

for example by adding a new step that propagates through the examples iterables, something like:

ex_iterable = ex_iterable.shift_rngs(value=worker_id)

Is this something you'd like to explore? Contributions on this subject are very welcome.
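
A rough sketch of what shifting a single generator by a worker id could look like. shift_rngs is only the name proposed above; the standalone helper `shift_rng` below is hypothetical and is not the library's implementation.

```python
import copy

import numpy as np


def shift_rng(generator: np.random.Generator, value: int) -> np.random.Generator:
    """Hypothetical helper: derive a new generator offset by `value`."""
    # Draw the base seed from a copy so the caller's generator is not advanced.
    base_seed = int(copy.deepcopy(generator).integers(0, 1 << 32))
    # Offset by the worker id; wrap around to keep the seed in the same range.
    return np.random.default_rng((base_seed + value) % (1 << 32))


base = np.random.default_rng(1234)
worker_rngs = [shift_rng(base, worker_id) for worker_id in range(4)]
draws = [rng.integers(0, 10, size=8).tolist() for rng in worker_rngs]
for worker_id, values in enumerate(draws):
    print(worker_id, values)
```

Because the base seed is read from a copy, calling the helper again with the same worker id reproduces the same stream, which keeps runs deterministic for a given seed.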

lhoestq avatar May 15 '25 13:05 lhoestq

Potentially, but busy. If anyone wants to take this up please feel free to, otherwise I may or may not revisit when I have free time.

For what it's worth, I got around this with:


import copy

import numpy as np
from datasets import iterable_dataset


class SeedGeneratorWithWorkerIterable(iterable_dataset._BaseExamplesIterable):
  """ExamplesIterable that seeds the rng with worker id."""

  def __init__(
    self,
    ex_iterable: iterable_dataset._BaseExamplesIterable,
    generator: np.random.Generator,
    rank: int = 0,
  ):
    """Constructor."""
    super().__init__()
    self.ex_iterable = ex_iterable
    self.generator = generator
    self.rank = rank

  def _init_state_dict(self) -> dict:
    self._state_dict = self.ex_iterable._init_state_dict()
    return self._state_dict

  def __iter__(self):
    """Data iterator."""
    effective_seed = copy.deepcopy(self.generator).integers(0, 1 << 63) - self.rank
    effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
    generator = np.random.default_rng(effective_seed)
    self.ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
    if self._state_dict:
      self._state_dict = self.ex_iterable._init_state_dict()
    yield from iter(self.ex_iterable)

  def shuffle_data_sources(self, generator):
    """Shuffle data sources."""
    ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
    return SeedGeneratorWithWorkerIterable(ex_iterable, generator=generator, rank=self.rank)

  def shard_data_sources(self, num_shards: int, index: int, contiguous=True):  # noqa: FBT002
    """Shard data sources."""
    ex_iterable = self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
    return SeedGeneratorWithWorkerIterable(ex_iterable, generator=self.generator, rank=index)

  @property
  def is_typed(self):
    return self.ex_iterable.is_typed

  @property
  def features(self):
    return self.ex_iterable.features

  @property
  def num_shards(self) -> int:
    """Number of shards."""
    return self.ex_iterable.num_shards

jonathanasdf avatar May 15 '25 20:05 jonathanasdf

Thanks for the detailed insights!

After reviewing the issue and the current implementation in iterable_dataset.py, I can confirm the cause:

When using interleave_datasets(..., seed=...) with num_workers > 1 (e.g. via DataLoader), every worker starts from the same RNG state, so each worker produces an identical sampling sequence. The seed is not modulated by worker_id; shuffle() adjusts its seed using only the epoch, which is the same in every worker.
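
The effect can be reproduced with NumPy alone. This is only an illustration of the seeding behavior, not how datasets samples internally; `sample_source_order` is a hypothetical helper written for this example.

```python
import numpy as np


def sample_source_order(seed, num_sources, num_draws):
    """Hypothetical helper: which source each draw comes from, for a given seed."""
    rng = np.random.default_rng(seed)
    probabilities = [1 / num_sources] * num_sources
    return rng.choice(num_sources, size=num_draws, p=probabilities).tolist()


seed, num_workers = 1234, 4

# Every worker uses the same seed: all workers pick sources in the same order.
same_seed = [sample_source_order(seed, 10, 16) for _ in range(num_workers)]

# Seed offset by worker id: each worker gets its own order.
per_worker = [sample_source_order(seed + worker_id, 10, 16) for worker_id in range(num_workers)]

print(all(order == same_seed[0] for order in same_seed))
print(len({tuple(order) for order in per_worker}))
```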

As @lhoestq suggested, a proper fix would involve introducing something like:

ex_iterable = ex_iterable.shift_rngs(worker_id)

@jonathanasdf I also really appreciate the workaround implementation shared above. It was helpful for validating the behavior and will help shape the general solution.

ArjunJagdale avatar Jun 29 '25 06:06 ArjunJagdale