
Unexpected Shuffling Behavior

mitchellnw opened this issue 4 years ago · 13 comments

Even with the buffer size and initial set to the size of the dataset, I am not seeing a completely shuffled dataset.

I have created a toy dataset with 10000 integers split into 100 shards.

A random batch looks something like:

tensor([[3251, 6459, 3253, 6470, 3256, 8078, 3257, 3255, 3259, 4836],
        [ 151,  149,  153, 9749,  155,  139, 8186, 8174, 6542, 1741],
        [9874, 1825, 8253, 6642, 8255, 6630, 8257, 8246, 8259,  218],
        [1998, 3512, 3553, 3520, 3555, 9930, 3557, 3500, 3572, 8324]])

and many integers are close together (e.g. <= 20 apart).

Why is this? Is there any way to get around this? I am doing contrastive learning, so diverse batches are very important.

The sample code I am using is

import torch
import os
import numpy as np
import webdataset as wds
import torch.distributed as dist
import torch.multiprocessing as mp
from webdataset.iterators import batched


def train(gpu, ngpus_per_node, dataloader):

    for epoch in range(3):
        if gpu == 0:
            print(f'\n----------- Epoch = {epoch} ----------\n')
        for i, x in enumerate(dataloader):
            xarr = torch.from_numpy(np.array(x).astype('int')).to(gpu)
            gathered_xarr = [
                torch.zeros_like(xarr) for _ in range(ngpus_per_node)
            ]
            dist.all_gather(gathered_xarr, xarr)
            gathered_xarr = torch.cat(gathered_xarr)

            if (i % 100) == 0:
                if gpu == 0:
                    print(gathered_xarr.cpu())

def main_worker(gpu, ngpus_per_node):

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:6100",
        world_size=ngpus_per_node,
        rank=gpu,
    )


    input_shards = 'shard_{00..99}.tar'
    size = 10000
    batch_size = 10
    workers = 4
    num_batches = size // batch_size
    dataset = wds.WebDataset(input_shards, length=num_batches, shardshuffle=True)
    dataset = dataset.shuffle(size, initial=size)
    dataset = (
        dataset
        .decode("pil")
        .rename(text="txt")
        .to_tuple("text")
    )
    dataset = dataset.batched(batch_size)
    dataloader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=workers,
    )

    train(gpu, ngpus_per_node, dataloader)

def main():

    torch.multiprocessing.set_start_method("spawn")

    ngpus_per_node = 4
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))

if __name__ == "__main__":
    main()

Any suggestions would be very appreciated.

Sorry I initially wrote this in another issue, which I did not realize was closed. https://github.com/webdataset/webdataset/issues/62

mitchellnw, May 18 '21

+1, having similar issues here. Looking at the original shuffling code (below), I have some questions:

  1. why is the buffer not shuffled before the final for loop?
  2. in the try portion, why is next called instead of appending sample?
    initial = min(initial, bufsize)
    buf = []
    startup = True
    for sample in data:
        if len(buf) < bufsize:
            try:
                buf.append(next(data))  # skipcq: PYL-R1708
            except StopIteration:
                pass
        k = rng.randint(0, len(buf) - 1)
        sample, buf[k] = buf[k], sample
        if startup and len(buf) < initial:
            buf.append(sample)
            continue
        startup = False
        yield sample
    for sample in buf:
        yield sample

Would something like this make sense?

    initial = min(initial, bufsize)
    buf = []
    startup = True
    for sample in data:
        if len(buf) < bufsize:
            buf.append(sample)
        if startup and len(buf) < initial:
            continue
        k = rng.randint(0, len(buf) - 1)
        sample, buf[k] = buf[k], sample
        startup = False
        yield sample
    random.shuffle(buf)
    for sample in buf:
        yield sample

I'd appreciate your input here!

gabrielilharco, May 18 '21

Why you are seeing the bunching.

You have 100 shards of 100 samples each, and 16 processes (4 GPUs x 4 loader workers) reading those shards. That means each worker gets only about 6 shards, and those are where all of that worker's samples come from during one epoch. Your shuffle buffer of 10000 samples will therefore only ever contain about 600 samples. Each of those six shards covers a range of 100 consecutive values, and batches are constructed within the workers, so it's not surprising that you see such clusters of values. For large training sets, this clustering goes away. In practice, it isn't much of a problem even for small training sets.
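
To make that arithmetic concrete, here is a back-of-the-envelope sketch using only the numbers from this issue (100 shards of 100 samples, 4 GPUs with 4 loader workers each, as in the sample code above); it is an illustration, not library code.

num_shards, samples_per_shard = 100, 100
num_nodes, workers_per_node = 4, 4

total_workers = num_nodes * workers_per_node                # 16 independent reader processes
shards_per_worker = num_shards // total_workers             # about 6 shards per worker per epoch
samples_per_worker = shards_per_worker * samples_per_shard

print(total_workers, shards_per_worker, samples_per_worker)
# 16 6 600 -> the 10000-sample shuffle buffer never holds more than ~600 samples per worker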

However, if you want to avoid it for small training sets, there are several things you can do:

Solution: Don't split by shard and instead select by slice.

This does roughly what indexed datasets do and is about as efficient; it's probably the most straightforward answer. I may add a couple of functions to WebDataset to make this easier.

offset = worker_id + num_workers * node_id
splitsize = num_workers * world_size
dataset = WebDataset(..., nodesplitter=identity, splitter=identity).slice(offset, 999999, splitsize).shuffle(...).decode(...)

Solution: Shuffle and rebatch after the loader.

This requires minimal changes and gives you shuffling between workers (but not between nodes).

loader.unbatched().shuffle(1000).batched(batch_size)
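
Spelled out against the sample code from the original post, that might look roughly like this (a hedged sketch; input_shards, batch_size, and workers mirror the poster's values):

import webdataset as wds

input_shards = 'shard_{00..99}.tar'   # values taken from the original post
batch_size, workers = 10, 4

# Per-worker pipeline as before, but with a smaller in-worker shuffle buffer.
dataset = (
    wds.WebDataset(input_shards, shardshuffle=True)
    .shuffle(1000)
    .decode("pil")
    .rename(text="txt")
    .to_tuple("text")
    .batched(batch_size)
)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=workers)

# Undo the per-worker batching, mix samples coming from different workers, and rebatch.
loader = loader.unbatched().shuffle(1000).batched(batch_size)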

Solution: Sample with replacement.

This is actually what I recommend. It's simpler, works better for both small and large datasets, and has some statistical advantages as well. However, it is not common practice.

dataset = WebDataset(..., nodesplitter=identity, splitter=identity).repeat().shuffle(...).decode(...).batched(...)
loader = WebLoader(dataset, ...)

If you're converting code that uses an epoch-driven main loop, you can use ResizedDataset to impose "epochs" on this infinite data stream.
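
A hedged sketch of that pattern, building on the snippet above; identity, the shard names, and especially the ResizedDataset call reflect my reading of the 2021-era API, so treat them as assumptions and check them against your installed version.

import webdataset as wds

identity = lambda x: x                 # pass shard lists through unchanged, as in the snippet above
input_shards = 'shard_{00..99}.tar'    # values taken from the original post
batch_size, workers = 10, 4
num_batches = 10000 // batch_size      # nominal epoch length in batches

dataset = (
    wds.WebDataset(input_shards, nodesplitter=identity, splitter=identity)
    .repeat()                          # loop over the data indefinitely
    .shuffle(1000)
    .decode("pil")
    .to_tuple("txt")
    .batched(batch_size)
)
# Assumption: ResizedDataset(dataset, length) wraps the stream with a fixed nominal epoch length.
dataset = wds.ResizedDataset(dataset, num_batches)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=workers)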

tmbdev, May 18 '21

why is the buffer not shuffled before the final for loop?

It makes sense to add that for clarity, though the existing logic already effectively shuffles the buffer. The case doesn't usually come up, since the shuffle buffer is usually smaller than the amount of data in the worker.

in the try portion, why is next called instead of appending sample?

Your code appends the sample and then uses the same sample to shuffle it into the buffer, leading to two copies of the sample being present in the shuffle buffer.

The existing code appends an extra sample for each sample that is returned, until the buffer is full. The code that does that sits by itself at the beginning of the loop and doesn't otherwise affect the rest of the loop.
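
For reference, here is the buffer-filling branch from the code quoted above, with comments added to spell that out (the code itself is unchanged):

if len(buf) < bufsize:
    try:
        buf.append(next(data))    # pull one *extra* sample off the iterator, so the
    except StopIteration:         # buffer keeps growing by one per loop iteration
        pass                      # until it reaches bufsize (or the data runs out)
k = rng.randint(0, len(buf) - 1)
sample, buf[k] = buf[k], sample   # swap the loop's sample with a random buffer slot,
                                  # so whatever gets yielded below is randomized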

tmbdev, May 20 '21

Thank you very much. In the end this code also helped resolve the issue:

https://github.com/webdataset/webdataset/pull/73

Are there any plans to enable this functionality in WebDataset? It turns out that having the same workers access the same shards every epoch leads to drastically worse performance in contrastive learning.

mitchellnw, May 24 '21

@mitchellnw I think these issues are probably more issues with documentation than missing functionality. For example, you can incorporate epoch-specific shuffling into a nodesplitter function; alternatively, these problems simply go away entirely if you use sampling with replacement or all-shards-on-all-nodes.

I'll see what I can do to document this better and maybe add explicit support for set_epoch.

tmbdev, May 25 '21

Thank you so much! It's been great to use WebDataset so far and would really appreciate an example of epoch-specific shuffling with a nodesplitter or explicit support for set_epoch.

Unfortunately all-shards-on-all-nodes i.i.d. sampling was surprisingly not as accurate -- perhaps this is why people don't usually do this.

mitchellnw, May 25 '21

Roughly, the way PyTorch handles spawn is that the workers are spawned in subprocesses where torch.distributed is not initialized; in addition, the Dataset instances are restored from pickled data on every epoch. I think the all-shards-on-all-nodes performance differences may be due to that as well. It also means that set_epoch in your patch wouldn't work as expected. The usual arguments to DataLoader (shuffle, sampler, batch_sampler) are explicitly disabled for IterableDataset, so they don't help us.

I have now refactored the code so that this should work more like you expect, and it also allows per-epoch shuffling.

If you don't need per-epoch shuffling of shards between workers, just don't do anything; these defaults should now work well enough.

If you do want epoch_shuffle, you can create the dataset instance as follows:

# (updated to new API)
shardlist = wds.PytorchShardList("imagenet-{000000..000015}.tgz", epoch_shuffle=True)
dataset = wds.WebDataset(shardlist, ...)...
loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)

And then you must put the epoch in the environment:

    loader = make_loader()
    for epoch in range(2):
        os.environ["WDS_EPOCH"] = str(epoch)
        for inputs, labels in loader:
            ...

You can find a preliminary example notebook here

https://github.com/webdataset/webdataset/blob/master/notebooks/multinode-test.ipynb

tmbdev, May 30 '21

@tmbdev How do you feel about using something like torch's samplers for shuffling? There are several sampling techniques available, e.g. sequential, random (with and without replacement), and weighted.

mynameisvinn, Jun 01 '21

In WebDataset, data is primarily shuffled at the shard level, and you get the equivalent of PyTorch's samplers now: that's the shardlist= argument. Your shardlist class can sample the shards any way it likes. A node-splitting sampler and sampling with replacement are provided. I will be adding a weighted shard sampler in the future, since it's needed for very large scale training.

Separately, you can also perform shuffling and resampling at the per sample level, but those need to be streaming operations. Those operations are found as filters or operations on datasets; for example, dataset.rsample lets you subsample randomly. I will probably add a class frequency equalizing stream sampler at some point.
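
For example (a hedged one-liner; as I understand it, rsample keeps each sample independently with roughly the given probability, but check the filter's docstring in your version):

dataset = dataset.rsample(0.1)   # keep about 10% of the sample stream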

WebDataset is a massively parallel I/O library, and that necessitates that you deal with parallel (shards) and serial (samplers) aspects separately. The WebDataset library is further constrained by how PyTorch's dataloader handles I/O processes.

(Note that if you have really complex I/O, shuffling, or augmentation needs, you might use something like nvlabs/tensorcom, which gives you complete freedom in shuffling samples between GPU nodes.)

tmbdev, Jun 01 '21

thanks @tmbdev for the comprehensive answer!

mynameisvinn, Jun 03 '21

Hi @tmbdev, unfortunately the old fix from your comment above is no longer working for me. It seems this may be due to some refactoring? Any guidance on what to do instead? Thank you!

mitchellnw, Jun 04 '21

Yes, sorry. You can now simply write:

shardlist = wds.PytorchShardList("imagenet-{000000..000015}.tgz", epoch_shuffle=True)
dataset = wds.WebDataset(shardlist, ...)...
loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)

That is, you can pass either URLs or a shardlist object to the Dataset constructor (shardlists are just IterableDatasets themselves). The rest of the code should work as before.

tmbdev, Jun 05 '21

Hi @tmbdev, have you implemented the weighted sampler yet? I am most interested in sampling within each shard; is it easily doable? Thanks!

EEthinker, Feb 03 '22

You can use RandomMix to mix sources with arbitrary probabilities, or you can use MultiShardSample to sample at the shard level.
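
A hedged sketch of the RandomMix route; I'm assuming it takes a list of datasets plus matching probabilities, and the shard names here are made up, so check both against your webdataset version:

import webdataset as wds

ds_a = wds.WebDataset("source-a-{000000..000009}.tar")   # hypothetical shard names
ds_b = wds.WebDataset("source-b-{000000..000009}.tar")
mixed = wds.RandomMix([ds_a, ds_b], probs=[0.7, 0.3])    # draw from ds_a about 70% of the time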

tmbdev, Mar 05 '23

"Separately, you can also perform shuffling and resampling at the per sample level, but those need to be streaming operations. Those operations are found as filters or operations on datasets; for example, dataset.rsample lets you subsample randomly. I will probably add a class frequency equalizing stream sampler at some point." Does this mean you lose the speed benefits of web dataset if you decide to do shuffling on the per sample level? Not sure what "streaming" means but I guess it means breaking the serial-within-shard abstraction and accessing samples individually.

richardrl, Apr 15 '24