
Caching shuffles by np.random.Generator results in unintuitive behavior

Open · el-hult opened this issue 6 months ago · 5 comments

Describe the bug

Create a dataset. Save it to disk. Load it from disk. Shuffle using a np.random.Generator. Iterate. Shuffle again. Iterate. The iterates are different, since the supplied np.random.Generator has advanced between the shuffles.

In a new run, load the dataset from disk again. Shuffle and iterate. The result is the same as in the previous run. Shuffle and iterate again, and this time the shuffling does not match the previous run: the order from the first shuffle comes back instead.

The motivation is that I have a deep learning loop like

for epoch in range(10):
    for batch in dataset.shuffle(generator=generator).iter(batch_size=32):
        ...  # do stuff

where I want a new shuffling at every epoch. Instead, I get the same shuffling at every epoch.

Steps to reproduce the bug

Save the code below as main.py and run it twice.

import datasets
import numpy as np

generator = np.random.default_rng(0)
ds = datasets.Dataset.from_dict(mapping={"X":range(1000)})
ds.save_to_disk("tmp")
print("First loop: ", end="")
for _ in range(10):
    print(next(ds.shuffle(generator=generator).iter(batch_size=1))['X'], end=", ")
print("")

print("Second loop: ", end="")
ds = datasets.Dataset.load_from_disk("tmp")
for _ in range(10):
    print(next(ds.shuffle(generator=generator).iter(batch_size=1))['X'], end=", ")
print("")

The output is:

$ python main.py 
Saving the dataset (1/1 shards): 100%|███████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 495019.95 examples/s]
First loop: 459, 739, 72, 943, 241, 181, 845, 830, 896, 334, 
Second loop: 741, 847, 944, 795, 483, 842, 717, 865, 231, 840,
$ python main.py 
Saving the dataset (1/1 shards): 100%|████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 22243.40 examples/s]
First loop: 459, 739, 72, 943, 241, 181, 845, 830, 896, 334, 
Second loop: 741, 741, 741, 741, 741, 741, 741, 741, 741, 741, 

On the second run, the second loop only prints "741, 741, 741, ...", which is not the desired output.
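One plausible explanation for the repeated 741, offered as an assumption about the library's internals rather than a confirmed diagnosis: for a dataset loaded from disk, `shuffle` caches the shuffled indices under a fingerprint of its arguments (including the generator's state), and on a cache hit it returns the stored indices without drawing from the generator. The generator then never advances, the fingerprint never changes, and every later call hits the same cache entry. The stdlib-only sketch below models that loop; `fingerprint`, `cached_shuffle`, and the dict standing in for cache files are hypothetical names, not part of `datasets`, and `random.Random` stands in for `np.random.Generator`.

```python
import hashlib
import pickle
import random


def fingerprint(rng: random.Random) -> str:
    # Hypothetical stand-in for a fingerprint computed from the shuffle()
    # arguments: a hash of the generator's current internal state.
    return hashlib.sha256(pickle.dumps(rng.getstate())).hexdigest()


def cached_shuffle(data, rng: random.Random, cache: dict) -> list:
    key = fingerprint(rng)
    if key in cache:
        # Cache hit: return the stored order WITHOUT touching rng, so the
        # generator's state (and therefore the next key) stays the same.
        return cache[key]
    order = list(data)
    rng.shuffle(order)  # only a cache miss advances the generator
    cache[key] = order
    return order


def first_items(rng: random.Random, cache: dict, n_calls: int = 5) -> list:
    # Like the repro script: shuffle repeatedly and keep the first element.
    return [cached_shuffle(range(1000), rng, cache)[0] for _ in range(n_calls)]


# Survives between "runs", like cache files stored next to the saved dataset.
disk_cache = {}

run1 = first_items(random.Random(0), disk_cache)  # every call is a cache miss
run2 = first_items(random.Random(0), disk_cache)  # every call hits the first entry
print("run 1:", run1)  # a fresh order on every call
print("run 2:", run2)  # run 1's first item, repeated
```

In this model the first run fills the cache with one entry per shuffle, and the second run, starting from the same generator state, gets stuck on the first entry, which is the same pattern as the "741, 741, ..." output.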

Expected behavior

I want the dataset to shuffle at every epoch since I provide it with a generator for shuffling.
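A possible way to get a new order at every epoch, sketched as a hypothetical stdlib-only model rather than the `datasets` API: the caller draws a fresh integer seed from its own generator at the start of each epoch, and the cache is keyed on that seed, so a cache hit can no longer stall the generator. In the loop from the description, the analogous (untested) change would be something like `dataset.shuffle(seed=int(generator.integers(2**32)))`, or passing `load_from_cache_file=False` to `shuffle`. Both `seed` and `load_from_cache_file` are existing `Dataset.shuffle` parameters, but whether either fully avoids the behavior reported here is an assumption. `cached_shuffle_with_seed` and `run_epochs` are illustrative names.

```python
import random


def cached_shuffle_with_seed(data, seed: int, cache: dict) -> list:
    # Hypothetical cache keyed on the integer seed: a hit returns exactly the
    # order a miss would have computed for that seed, so caching is harmless.
    if seed not in cache:
        order = list(data)
        random.Random(seed).shuffle(order)
        cache[seed] = order
    return cache[seed]


def run_epochs(cache: dict, n_epochs: int = 10) -> list:
    outer = random.Random(0)  # stands in for np.random.default_rng(0)
    firsts = []
    for _ in range(n_epochs):
        # The caller advances its own generator every epoch, whether or not
        # the shuffle below is served from the cache.
        seed = outer.getrandbits(32)
        order = cached_shuffle_with_seed(range(1000), seed, cache)
        firsts.append(order[0])
    return firsts


disk_cache = {}
run1 = run_epochs(disk_cache)  # cache misses: fills disk_cache
run2 = run_epochs(disk_cache)  # cache hits: same seeds, same orders
print("run 1:", run1)
print("run 2:", run2)
```

Under this model each epoch gets a new order, and rerunning with the same starting seed reproduces the same sequence of orders, cache or no cache.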

Environment info

Datasets version 2.21.0, Ubuntu Linux.

el-hult, Aug 26 '24 10:08