Recommended practice to shuffle data with datapipes differently every epoch

Open BarclayII opened this issue 1 year ago • 4 comments

📚 The doc issue

I was trying torchdata 0.4.0 and found that shuffling with datapipes always yields the same result across different epochs, unless I shuffle again at the beginning of every epoch.

# same_result.py
import torch
import torchdata.datapipes as dp
X = torch.randn(200, 5)
dpX = dp.map.SequenceWrapper(X)
dpXS = dpX.shuffle()
for _ in range(5):
    for i in dpXS:
        print(i)   # always prints the same value
        break

# different_result.py
import torch
import torchdata.datapipes as dp
X = torch.randn(200, 5)
dpX = dp.map.SequenceWrapper(X)
for _ in range(5):
    dpXS = dpX.shuffle()
    for i in dpXS:
        print(i)   # prints different values
        break

I wonder what the recommended practice is for shuffling the data at the beginning of every epoch. Neither the documentation nor the examples seem to answer this question.

Suggest a potential alternative/fix

No response

BarclayII avatar Aug 05 '22 02:08 BarclayII

Technically speaking, you should use DataLoader with your DataPipe to get different shuffling randomness in every epoch. However, there are a few bugs in DataLoader and map.SequenceWrapper; I will send a patch to fix them.
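For illustration, here is a minimal sketch of the intended pattern. This assumes the DataLoader/SequenceWrapper fixes mentioned above have landed; the per-epoch reshuffling relies on DataLoader re-seeding the shuffler, which is exactly what the current bugs break:

# Sketch of the intended usage (assumes the fixes above): DataLoader
# re-seeds the shuffler each epoch, so a single shuffled datapipe
# yields a different order on every pass.
import torch
import torchdata.datapipes as dp

X = torch.randn(200, 5)
dpXS = dp.map.SequenceWrapper(X).shuffle()
dl = torch.utils.data.DataLoader(dpXS)
for epoch in range(5):
    for i in dl:
        print(i)   # should print different values per epoch once fixed
        break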

Besides, it seems the behavior of MapDataPipe is different from IterDataPipe:

  • IterDataPipe.shuffle shuffles lazily during iteration, so the shuffle order changes across epochs
  • MapDataPipe.shuffle shuffles at construction time, so the shuffle order is fixed

We need to find a way to align them; the sketch below demonstrates the difference.
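A minimal demonstration against torchdata 0.4, assuming the behavior described in the two bullets above (exact randomness depends on seeding):

import torch
import torchdata.datapipes as dp

X = torch.randn(200, 5)

# IterDataPipe: shuffling happens lazily in __iter__, so each fresh
# iteration draws a new order (when no fixed seed is applied)
it_pipe = dp.iter.IterableWrapper(X).shuffle(buffer_size=len(X))
print(next(iter(it_pipe)))
print(next(iter(it_pipe)))   # typically a different row

# MapDataPipe: the permutation is drawn once inside .shuffle(), so
# every pass over the pipe replays the same order
map_pipe = dp.map.SequenceWrapper(X).shuffle()
print(map_pipe[0])
print(map_pipe[0])   # always the same row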

ejguan avatar Aug 05 '22 13:08 ejguan

Could you please try using IterDataPipe via dp.iter.IterableWrapper and providing the datapipe to DataLoader as a temporary workaround?

ejguan avatar Aug 05 '22 13:08 ejguan

This is what I have now:

import torch
import torchdata.datapipes as dp

X = torch.randn(200, 5)
dpX = dp.iter.IterableWrapper(X)
dpXS = dpX.shuffle(buffer_size=X.shape[0])  # buffer the whole dataset
dl = torch.utils.data.DataLoader(dpXS)
for _ in range(5):
    for i in dl:
        print(i)   # prints different values each epoch
        break

It seems that to emulate the old DataLoader's shuffle=True behavior, I'll need to pass the size of the dataset explicitly to the shuffle() method.

BarclayII avatar Aug 06 '22 04:08 BarclayII

The shuffle for IterDataPipe has to be a buffered shuffle, since there is no concept of indices. So, to achieve a global shuffle, you have to pass buffer_size=X.shape[0] to the shuffle op.
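To illustrate the idea, here is a rough sketch of buffered shuffling in general, not torchdata's actual implementation: items are held in a buffer of bounded size and emitted in random order, so only a buffer at least as large as the stream gives a fully global shuffle.

# Rough illustration of buffered shuffling (not torchdata's code):
# keep up to `buffer_size` items in a buffer and emit a random one as
# each new item arrives; leftovers are shuffled and flushed at the end.
import random

def buffered_shuffle(stream, buffer_size):
    buf = []
    for item in stream:
        buf.append(item)
        if len(buf) >= buffer_size:
            yield buf.pop(random.randrange(len(buf)))
    random.shuffle(buf)
    yield from buf

print(list(buffered_shuffle(range(10), buffer_size=4)))   # only a local shuffle
print(list(buffered_shuffle(range(10), buffer_size=10)))  # global shuffle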

ejguan avatar Aug 08 '22 13:08 ejguan