
Recommended way to shuffle intra and inter archives?

Open NicolasHug opened this issue 1 year ago • 8 comments

Say I have a bunch of archives containing samples. In my case each archive is a pickle file containing a list of samples, but it could be a tar or something else.

I want to shuffle between archives (inter) and within archives (intra). My current way of doing it is below. Is there a more canonical solution?

from torchdata.dataloader2 import DataLoader2, adapter
from torchdata.datapipes.iter import IterDataPipe, FileLister, IterableWrapper
from pathlib import Path

import pickle

# Create archives
root = Path("/tmp/dataset/")
root.mkdir(parents=True, exist_ok=True)  # make sure the directory exists
with open(root / "1.pkl", "wb") as f:
    pickle.dump(list(range(10)), f)
with open(root / "2.pkl", "wb") as f:
    pickle.dump(list(range(10, 20)), f)

class PickleLoaderDataPipe(IterDataPipe):
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for path in self.source_datapipe:
            with open(path, "rb") as f:
                yield pickle.load(f)  # <- this is a list

class ConcaterIterable(IterDataPipe):
    # Same as unbatch(), kinda
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for iterable in self.source_datapipe:
            yield from iterable

def intra_archive_shuffle(archive_content):
    return IterableWrapper(archive_content).shuffle()
    
    
dp = FileLister(str(root), masks=["*.pkl"])
dp = dp.shuffle()  # inter-archive shuffling
dp = PickleLoaderDataPipe(dp)
dp = dp.map(intra_archive_shuffle)
dp = ConcaterIterable(dp)  # Note: unbatch doesn't work because it's a datapipe of datapipes

print(list(dp))

NicolasHug avatar Aug 15 '22 17:08 NicolasHug

Also, I don't know if this is expected or if there is a problem with DataLoader2, but if I call shuffle().set_shuffle(False) (like we do in torchvision), then the datapipe doesn't get properly shuffled: only the inter-archive shuffling happens:

# same as before, just calling `.set_shuffle(False)` to avoid shuffling by default

def intra_archive_shuffle(archive_content):
    return IterableWrapper(archive_content).shuffle().set_shuffle(False)
    
    
dp = FileLister(str(root), masks=["*.pkl"])
dp = dp.shuffle().set_shuffle(False)
dp = PickleLoaderDataPipe(dp)
dp = dp.map(intra_archive_shuffle)
dp = ConcaterIterable(dp)  # Note: unbatch doesn't work because it's a datapipe of datapipes

#print(list(dp))
dl = DataLoader2(dp, datapipe_adapter_fn=adapter.Shuffle())
print(list(dl))

The result will only ever be:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

or

[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

NicolasHug avatar Aug 15 '22 17:08 NicolasHug

OK so regarding the use of shuffle().set_shuffle(False) in the message above, the problem comes from traverse():

traverse(dp)
{ConcaterIterable: {MapperIterDataPipe: {PickleLoaderDataPipe: {ShufflerIterDataPipe: {FileListerIterDataPipe: {IterableWrapperIterDataPipe: {}}}}}}}

The intra-archive shufflers are "hidden" by the MapperIterDataPipe, and so they are not re-activated by the adapter.

There's probably nothing traverse() can do here... Or maybe there is? @VitalyFedyunin, would you like me to open a more specific issue to track this?

I wonder what it means w.r.t. users' code though. A few outstanding questions on my side:

  • Do we need to tell users to never create datapipes in a call to map()?
  • Are there other blind spots like this one?
  • Going back to my initial question from this post, since my snippet doesn't actually work: What is a good canonical way to shuffle intra and inter archives? 😅

NicolasHug avatar Aug 15 '22 17:08 NicolasHug

Do we need to tell users to never create datapipes in a call to map()?

The problem is more general in that it applies to any kind of wrapper that generates a datapipe, not just to a Mapper. For example the solution below still has the exact same problem because the Shuffler instances are hidden within the PickleLoaderDataPipe:

class PickleLoaderDataPipe(IterDataPipe):
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for path in self.source_datapipe:
            with open(path, "rb") as f:
                yield IterableWrapper(pickle.load(f)).shuffle().set_shuffle(False)  # <- shuffling won't work

NicolasHug avatar Aug 15 '22 18:08 NicolasHug

What is a good canonical way to shuffle intra and inter archives?

I think the best way is to use in_batch_shuffle, though we would need to add further randomness control to it if we want it to be enabled/disabled.
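
For reference, a minimal sketch of what that could look like, reusing the PickleLoaderDataPipe from the first message (which yields one plain list per archive): in_batch_shuffle shuffles each list as it flows through, and unbatch() then flattens the lists into individual samples.

dp = FileLister(str(root), masks=["*.pkl"])
dp = dp.shuffle()              # inter-archive shuffling
dp = PickleLoaderDataPipe(dp)  # yields one list of samples per archive
dp = dp.in_batch_shuffle()     # intra-archive shuffling
dp = dp.unbatch()              # works here: these are plain lists, not datapipes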

cc: @ejguan

NivekT avatar Aug 15 '22 19:08 NivekT

The problem is more general in that it applies to any kind of wrapper that generates a datapipe, not just to a Mapper. For example the solution below still has the exact same problem because the Shuffler instances are hidden within the PickleLoaderDataPipe:

I feel like this is not solvable by the traverse function. I think we should treat the DataPipe pipeline as a static graph of operations. The PickleLoader adds DataPipes at runtime, which leaves torchdata with no way to track the DAG.

ejguan avatar Aug 15 '22 19:08 ejguan

I think the question here is how big the pickled data is. If it fits in memory nicely (I presume it has to, given the nature of pickle), the best option would be to shuffle the list using standard Python functions:

import random

def intra_archive_shuffle(archive_content):
    random.shuffle(archive_content)  # shuffles the list in place
    return archive_content
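
A sketch of how that could slot into the original pipeline: since the mapped function returns plain lists rather than datapipes, the built-in unbatch() works and the ConcaterIterable workaround is no longer needed (assuming the FileLister / PickleLoaderDataPipe setup from the first message).

dp = FileLister(str(root), masks=["*.pkl"])
dp = dp.shuffle()                   # inter-archive shuffling
dp = PickleLoaderDataPipe(dp)       # yields one list per archive
dp = dp.map(intra_archive_shuffle)  # intra-archive shuffling, still plain lists
dp = dp.unbatch()                   # flattens the lists into samples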

Also, in_batch_shuffle does exactly this, so it is the preferable way. Note: it is not sensitive to shuffle enable/disable; let us know if that is an issue.

In the other case (say you are using tar or any other stream unpacker), it is better to yield item by item and use a second instance of shuffle:

dp = FileLister(str(root), masks=["*.pkl"])
dp = dp.shuffle().set_shuffle(False)
dp = PickleLoaderDataPipe(dp)  # assumed here to yield individual items (see the variant below)
dp = dp.shuffle()  # yes, again: play with the buffer size here
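
For completeness, a hypothetical item-yielding variant of the loader that this pattern assumes, so that the downstream Shuffler buffers individual samples rather than whole lists:

class PickleItemLoaderDataPipe(IterDataPipe):
    # Hypothetical variant of PickleLoaderDataPipe that yields samples
    # one by one instead of a whole list per archive.
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for path in self.source_datapipe:
            with open(path, "rb") as f:
                yield from pickle.load(f)  # stream items, not the list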

Also:

Now I can see where DataPipes of DataPipes is coming from, and I want to note that it is an anti-pattern for DataPipe graphs; we should include a linter for it.

VitalyFedyunin avatar Aug 15 '22 20:08 VitalyFedyunin

Thanks a lot everyone for your input!

dp = dp.shuffle() # yes again, play with buffer size here

Ah, that's a great idea: if we set buffer_size=archive_size it should do what I want (the shuffling is a tiny bit different, but that doesn't matter; it's actually better). However, we would need to know the size of the archives, which might be tricky / cumbersome in general?

Looks like in_batch_shuffle would be a good alternative; however, I do need the ability to disable shuffling, because the same code path would be used for the training sets (shuffle enabled) and the test sets (shuffle disabled).
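
(For reference, a sketch of the toggle I'm after, assuming the Shuffler-based graph from above where every Shuffler is visible to traverse(), and assuming the Shuffle adapter accepts an enable flag; the build_dp() helper is hypothetical and just rebuilds the pipeline from scratch per split.)

def build_dp():
    # hypothetical helper: rebuilds the shuffle-on-demand pipeline above
    dp = FileLister(str(root), masks=["*.pkl"]).shuffle().set_shuffle(False)
    dp = PickleLoaderDataPipe(dp)
    return dp.shuffle().set_shuffle(False)

train_dl = DataLoader2(build_dp(), datapipe_adapter_fn=adapter.Shuffle(True))   # shuffling on
test_dl = DataLoader2(build_dp(), datapipe_adapter_fn=adapter.Shuffle(False))   # shuffling off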

NicolasHug avatar Aug 16 '22 08:08 NicolasHug

Looks like in_batch_shuffle would be a good alternative; however, I do need the ability to disable shuffling, because the same code path would be used for the training sets (shuffle enabled) and the test sets (shuffle disabled).

Yeah. This is one of the TODOs at the top of my list. I need to align in_batch_shuffle with shuffle in terms of API.

Edit: shuffle_setting should also respect those two types of DataPipe.

ejguan avatar Aug 16 '22 13:08 ejguan