Unexpected Shuffling Behavior
Even with buffer size and initial set to the size of the dataset, I am not seeing a completely shuffled dataset.
I have created a toy dataset with 10000 integers split into 100 shards.
A random batch looks something like:
tensor([[3251, 6459, 3253, 6470, 3256, 8078, 3257, 3255, 3259, 4836],
[ 151, 149, 153, 9749, 155, 139, 8186, 8174, 6542, 1741],
[9874, 1825, 8253, 6642, 8255, 6630, 8257, 8246, 8259, 218],
[1998, 3512, 3553, 3520, 3555, 9930, 3557, 3500, 3572, 8324]])
and many integers are close together (e.g. <= 20 apart).
Why is this? Is there any way to get around this? I am doing contrastive learning, so diverse batches are very important.
The sample code I am using is
import torch
import os
import numpy as np
import webdataset as wds
import torch.distributed as dist
import torch.multiprocessing as mp
from webdataset.iterators import batched


def train(gpu, ngpus_per_node, dataloader):
    for epoch in range(3):
        if gpu == 0:
            print(f'\n----------- Epoch = {epoch} ----------\n')
        for i, x in enumerate(dataloader):
            xarr = torch.from_numpy(np.array(x).astype('int')).to(gpu)
            gathered_xarr = [
                torch.zeros_like(xarr) for _ in range(ngpus_per_node)
            ]
            dist.all_gather(gathered_xarr, xarr)
            gathered_xarr = torch.cat(gathered_xarr)
            if (i % 100) == 0:
                if gpu == 0:
                    print(gathered_xarr.cpu())


def main_worker(gpu, ngpus_per_node):
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:6100",
        world_size=ngpus_per_node,
        rank=gpu,
    )
    input_shards = 'shard_{00..99}.tar'
    size = 10000
    batch_size = 10
    workers = 4
    num_batches = size // batch_size
    dataset = wds.WebDataset(input_shards, length=num_batches, shardshuffle=True)
    dataset = dataset.shuffle(size, initial=size)
    dataset = (
        dataset
        .decode("pil")
        .rename(text="txt")
        .to_tuple("text")
    )
    dataset = dataset.batched(batch_size)
    dataloader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=workers,
    )
    train(gpu, ngpus_per_node, dataloader)


def main():
    torch.multiprocessing.set_start_method("spawn")
    ngpus_per_node = 4
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))


if __name__ == "__main__":
    main()
Any suggestions would be very appreciated.
Sorry I initially wrote this in another issue, which I did not realize was closed. https://github.com/webdataset/webdataset/issues/62
+1, having similar issues here. Looking at the original shuffling code (below), I have some questions:
- why is the buffer not shuffled before the final for loop?
- in the try portion, why is next() called instead of appending sample?
initial = min(initial, bufsize)
buf = []
startup = True
for sample in data:
    if len(buf) < bufsize:
        try:
            buf.append(next(data))  # skipcq: PYL-R1708
        except StopIteration:
            pass
    k = rng.randint(0, len(buf) - 1)
    sample, buf[k] = buf[k], sample
    if startup and len(buf) < initial:
        buf.append(sample)
        continue
    startup = False
    yield sample
for sample in buf:
    yield sample
Would something like this make sense?
initial = min(initial, bufsize)
buf = []
startup = True
for sample in data:
    if len(buf) < bufsize:
        buf.append(sample)
    if startup and len(buf) < initial:
        continue
    k = rng.randint(0, len(buf) - 1)
    sample, buf[k] = buf[k], sample
    startup = False
    yield sample
random.shuffle(buf)
for sample in buf:
    yield sample
I'd kindly appreciate your input here!
Why you are seeing the bunching.
You have 100 shards of 100 samples each. You also have 16 processes reading those shards. That means that you only have 6 shards per worker and that's where all the samples for that worker come from during one epoch. Your shuffle buffer of 10000 samples will only ever contain 600 samples. Each of these six shards will sample from a range of 100 values, and you construct batches within the workers, so it's not surprising that you should see such clusters of values. For large training sets, this clustering of values goes away. In practice, it isn't even much of a problem for small training sets.
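To make that arithmetic concrete, here is the same calculation spelled out (numbers taken from the original post):

num_shards = 100
samples_per_shard = 100
world_size = 4            # GPUs/processes in the original post
workers_per_gpu = 4       # DataLoader workers per GPU

readers = world_size * workers_per_gpu                      # 16 independent reading processes
shards_per_reader = num_shards // readers                   # 6 shards per reader per epoch
effective_buffer = shards_per_reader * samples_per_shard    # only 600 samples ever in the buffer

print(readers, shards_per_reader, effective_buffer)         # -> 16 6 600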
However, if you want to avoid it for small training sets, there are several things you can do:
Solution: Don't split by shard and instead select by slice.
This does roughly what indexed datasets do and is about as efficient; it's probably the most straightforward answer. I may add a couple of functions to WebDataset to make this easier.
offset = worker_id + num_workers * node_id
splitsize = num_workers * world_size
dataset = WebDataset(..., nodesplitter=identity, splitter=identity).slice(offset, 999999, splitsize).shuffle(...).decode(...)
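Filled out as a minimal sketch (the .slice()/identity API follows the snippet above and the shard pattern comes from the original post; the placeholder worker_id/node_id values and the shuffle/decode parameters are illustrative, since in a real run the worker id comes from the DataLoader worker and the node id from the distributed rank):

import webdataset as wds

def identity(x):
    return x

# Placeholders for this sketch; replace with the real worker/node indices.
worker_id, node_id = 0, 0
num_workers, world_size = 4, 4

offset = worker_id + num_workers * node_id
splitsize = num_workers * world_size

dataset = (
    wds.WebDataset("shard_{00..99}.tar", nodesplitter=identity, splitter=identity)
    .slice(offset, 999999, splitsize)   # each reader takes every splitsize-th sample
    .shuffle(1000)
    .decode("pil")
    .to_tuple("txt")
)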
Solution: Shuffle and rebatch after the loader.
This requires minimal changes and gives you shuffling between workers (but not between nodes).
loader.unbatched().shuffle(1000).batched(batch_size)
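Applied to the pipeline from the original post, that looks roughly like the following sketch (buffer sizes and batch size are illustrative):

import webdataset as wds

batch_size = 10
dataset = (
    wds.WebDataset("shard_{00..99}.tar", shardshuffle=True)
    .shuffle(1000)
    .decode("pil")
    .to_tuple("txt")
    .batched(batch_size)          # workers still build their own (clustered) batches
)

loader = wds.WebLoader(dataset, batch_size=None, num_workers=4)
# In the main process: break the per-worker batches apart, mix samples coming
# from different workers, and re-batch.
loader = loader.unbatched().shuffle(1000).batched(batch_size)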
Solution: Sample with replacement.
This is actually what I recommend. It's simpler, works better for both small and large datasets, and has some statistical advantages as well. However, it is not common practice.
dataset = WebDataset(..., nodesplitter=identity, splitter=identity).repeat().shuffle(...).decode(...).batched(...)
loader = WebLoader(dataset, ...)
If you're converting code that uses an epoch-driven main loop, you can use ResizedDataset to impose "epochs" on this infinite data stream.
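A minimal sketch of that recipe, using the .repeat() and ResizedDataset names from this comment; the exact ResizedDataset signature and the epoch length are assumptions (newer releases expose a similar .with_epoch() method):

import webdataset as wds

def identity(x):
    return x

batch_size = 10
dataset = (
    wds.WebDataset("shard_{00..99}.tar", nodesplitter=identity, splitter=identity)
    .repeat()                  # resample shards forever instead of making one pass
    .shuffle(1000)
    .decode("pil")
    .to_tuple("txt")
    .batched(batch_size)
)

# Pretend the infinite stream is 1000 batches long so an epoch-driven training
# loop keeps working unchanged (assumed old-API name and positional length).
dataset = wds.ResizedDataset(dataset, 1000)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=4)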
why is the buffer not shuffled before the final for loop?
It makes sense to add that for clarity, though the existing logic already effectively shuffles the buffer. The case doesn't usually come up, since the shuffle buffer is usually smaller than the amount of data in the worker.
in the try portion, why is next called? Instead of appending sample?
Your code appends the sample and then uses the same sample to shuffle it into the buffer, leading to two copies of the sample being present in the shuffle buffer.
The existing code appends an extra sample for each sample that is returned until the buffer is full. The code that does this is self-contained at the beginning of the loop and doesn't otherwise affect the rest of the loop.
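For anyone following along, here is the same buffering logic quoted earlier, pulled out into a tiny standalone generator so its behavior is easy to inspect; the bufsize/initial values and the fixed RNG seed are just for this demo:

import random

def shuffled(data, bufsize=4, initial=4, rng=random.Random(0)):
    data = iter(data)
    buf = []
    startup = True
    for sample in data:
        if len(buf) < bufsize:
            try:
                buf.append(next(data))   # pull one extra sample while filling the buffer
            except StopIteration:
                pass
        k = rng.randint(0, len(buf) - 1)
        sample, buf[k] = buf[k], sample
        if startup and len(buf) < initial:
            buf.append(sample)
            continue
        startup = False
        yield sample
    for sample in buf:
        yield sample

print(list(shuffled(range(8))))   # each input appears exactly once, in a shuffled order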
Thank you very much. In the end this code also helped resolve the issue:
https://github.com/webdataset/webdataset/pull/73
Are there any plans to enable this functionality in WebDataset? It turns out that having the same workers access the same shards every epoch leads to drastically worse performance in contrastive learning.
@mitchellnw I think these issues are probably more issues with documentation than missing functionality. For example, you can incorporate epoch-specific shuffling into a nodesplitter function; alternatively, these problems simply go away entirely if you use sampling with replacement or all-shards-on-all-nodes.
I'll see what I can do to document this better and maybe add explicit support for set_epoch.
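Until then, here is a hedged sketch of what epoch-specific shuffling in a nodesplitter could look like; the environment-variable names (EPOCH, RANK, WORLD_SIZE) and the round-robin split are assumptions, not part of the library:

import os
import random
import webdataset as wds

def epoch_shuffling_nodesplitter(urls):
    # Assumption: the training loop exports these before each epoch / at startup.
    epoch = int(os.environ.get("EPOCH", "0"))
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    urls = list(urls)
    random.Random(epoch).shuffle(urls)   # same epoch-seeded permutation on every node
    return urls[rank::world_size]        # each node keeps its own slice of the permutation

dataset = wds.WebDataset("shard_{00..99}.tar", nodesplitter=epoch_shuffling_nodesplitter)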
Thank you so much! It's been great to use WebDataset so far and would really appreciate an example of epoch-specific shuffling with a nodesplitter or explicit support for set_epoch.
Unfortunately all-shards-on-all-nodes i.i.d. sampling was surprisingly not as accurate -- perhaps this is why people don't usually do this.
Roughly, the way PyTorch handles spawn is that the workers are spawned in subprocesses where torch.distributed is not initialized; in addition, the Dataset instances are restored from pickled data on every epoch. I think the all-shards-on-all-nodes performance differences may be due to that as well. That also means that set_epoch in your patch wouldn't work as expected. The usual arguments to DataLoader (shuffle, sampler, batch_sampler) are explicitly disabled in DataLoader for IterableDataset, so they don't help us.
I have refactored the code in a way that should work more like you expect and that also allows per-epoch shuffling.
If you don't need per-epoch shuffling of shards between workers, just don't do anything; these defaults should now work well enough.
If you do want epoch_shuffle, you can create the dataset instance as follows:
# (updated to new API)
shardlist = wds.PytorchShardList("imagenet-{000000..000015}.tgz", epoch_shuffle=True)
dataset = wds.WebDataset(shardlist, ...)...
loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)
And then you must put the epoch in the environment:
loader = make_loader()
for epoch in range(2):
    os.environ["WDS_EPOCH"] = str(epoch)
    for inputs, labels in loader:
        ...
You can find a preliminary example notebook here
https://github.com/webdataset/webdataset/blob/master/notebooks/multinode-test.ipynb
@tmbdev How do you feel about using something like torch's samplers for shuffling? There are several sampling techniques available, eg sequential, random (with and without replacement), weighted.
In WebDataset, data is primarily shuffled at the shard level, and you get the equivalent of PyTorch's samplers now: that's the shardlist= argument. Your shardlist class can sample the shards any way it likes. A node-splitting sampler and sampling with replacement are provided. I will be adding a weighted shard sampler in the future, since it's needed for very large scale training.
Separately, you can also perform shuffling and resampling at the per sample level, but those need to be streaming operations. Those operations are found as filters or operations on datasets; for example, dataset.rsample lets you subsample randomly. I will probably add a class frequency equalizing stream sampler at some point.
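As a small illustration of those per-sample stream operations (the URL and the parameters are illustrative):

import webdataset as wds

dataset = (
    wds.WebDataset("shard_{00..99}.tar")
    .rsample(0.5)     # keep each sample independently with probability 0.5
    .shuffle(1000)    # approximate shuffle with a bounded 1000-sample buffer
    .decode("pil")
)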
WebDataset is a massively parallel I/O library, and that necessitates that you deal with parallel (shards) and serial (samplers) aspects separately. The WebDataset library is further constrained by how PyTorch's dataloader handles I/O processes.
(Note that if you have really complex I/O, shuffling, or augmentation needs, you might use something like nvlabs/tensorcom, which gives you complete freedom in shuffling samples between GPU nodes.)
thanks @tmbdev for the comprehensive answer!
Hi @tmbdev, unfortunately the old fix from your comment above is no longer working for me. Seems this is maybe due to some refactoring? Any guidance on what to do instead? Thank you!
Yes, sorry. You can now simply write:
shardlist = wds.PytorchShardList("imagenet-{000000..000015}.tgz", epoch_shuffle=True)
dataset = wds.WebDataset(shardlist, ...)...
loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)
That is, you can either pass in URLs or a shardlist object into the Dataset constructor (shardlists are just IterableDatasets themselves). The rest of the code should work as before.
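For completeness, here is a hedged sketch of a hand-rolled shardlist along those lines; whether each entry should be a bare URL or a dict with a "url" key depends on the webdataset version, so the dict form below is an assumption:

import random
from torch.utils.data import IterableDataset
import webdataset as wds

class RandomShardList(IterableDataset):
    """Yield the shards in a fresh random order on every iteration (i.e. every epoch)."""

    def __init__(self, urls):
        self.urls = list(urls)

    def __iter__(self):
        urls = self.urls[:]
        random.shuffle(urls)
        for url in urls:
            yield dict(url=url)   # assumption: shard entries are dicts with a "url" key

shardlist = RandomShardList([f"shard_{i:02d}.tar" for i in range(100)])
dataset = wds.WebDataset(shardlist).decode("pil").to_tuple("txt")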
Hi, @tmbdev have you implemented the weighted sampler yet? I am most interested in sampling within each shard, is it easily doable? Thanks!
You can use RandomMix to mix sources with arbitrary probabilities, or you can use MultiShardSample to sample at the shard level.
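A hedged sketch of the RandomMix approach (the class name comes from the comment above; the URLs and mixing weights are illustrative, and the positional weights argument is an assumption about the signature):

import webdataset as wds

common = wds.WebDataset("common-{000..099}.tar").decode("pil")
rare = wds.WebDataset("rare-{000..009}.tar").decode("pil")

# Draw each next sample from `rare` with probability ~0.3 and from `common` otherwise.
mixed = wds.RandomMix([common, rare], [0.7, 0.3])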
"Separately, you can also perform shuffling and resampling at the per sample level, but those need to be streaming operations. Those operations are found as filters or operations on datasets; for example, dataset.rsample lets you subsample randomly. I will probably add a class frequency equalizing stream sampler at some point." Does this mean you lose the speed benefits of web dataset if you decide to do shuffling on the per sample level? Not sure what "streaming" means but I guess it means breaking the serial-within-shard abstraction and accessing samples individually.