Buckets of elements of different sizes
I am attempting to use webdataset to load a dataset with subsections/buckets organized by image size. To do this, I've organized the files so that each bucket has its own designated folder filled with tars of images of that size, and I create a WebDataset object for each bucket. Each dataset is then turned into an iterator by calling datasets[i] = iter(datasets[i]). At each training step I sample a random bucket index, weighted by the precomputed number of elements in each bucket, and call next() on that iterator to get a batch.
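Roughly, the setup looks like the sketch below (the bucket folders, shard pattern, sample counts, and batch size are placeholders, not my real values):

import random
import webdataset as wds

# Placeholder bucket folders, shard pattern, and precomputed sample counts.
bucket_dirs = ["data/bucket_256", "data/bucket_512", "data/bucket_768"]
bucket_counts = [100_000, 50_000, 25_000]
batch_size = 32
num_steps = 1_000

# One WebDataset per bucket, each producing batches of same-sized images.
datasets = [
    wds.WebDataset(f"{d}/shard-{{000000..000099}}.tar")
       .decode("pil")
       .to_tuple("jpg", "json")
       .batched(batch_size)
    for d in bucket_dirs
]

# Turn each dataset into an iterator once; at every step pick a bucket at
# random, weighted by its precomputed element count, and pull one batch.
iterators = [iter(ds) for ds in datasets]

for step in range(num_steps):
    i = random.choices(range(len(iterators)), weights=bucket_counts, k=1)[0]
    batch = next(iterators[i])
    # ... run one training step on `batch` ...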
However, this seems to be causing memory issues that scale in unusual ways:
- With 2 nodes and 6 workers, training survives 1 step before running out of memory.
- With 2 nodes and 2 workers, it survives 6 steps.
- With 2 nodes and 1 worker, we get 14 steps in.
- With 1 node and 2 workers, we get a full 270 steps.
Looking at CPU memory usage, it seems to increase by roughly 7 GB with each step, give or take, but given a 1 TB RAM allowance the job often dies before hitting that limit.
I have tried some other experiments to better understand this. For instance, I used only 2 datasets and alternated between the two every 4 steps, to get a sense of whether the switching itself is causing the issue: does memory continue to increase with each switch, or does it stay constant after a switch? A rough sketch of that experiment follows.
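The alternation experiment looks roughly like this (the two streams here are stand-in generators, not my real pipelines):

import itertools

# Stand-in streams playing the role of two per-bucket WebDataset iterators.
stream_a = ({"bucket": "a", "idx": i} for i in itertools.count())
stream_b = ({"bucket": "b", "idx": i} for i in itertools.count())
iterators = [stream_a, stream_b]

for step in range(16):           # placeholder step count
    active = (step // 4) % 2     # switch to the other bucket every 4 steps
    batch = next(iterators[active])
    # ... log process RSS here to see whether memory grows on each switch ...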
Does this seem like a logical way to approach what I am trying to do? If not, are there other possible solutions through webdataset? When I have previously worked with dataloaders that sample from sub-buckets, I used map-style datasets in torch together with a custom sampler object passed to the DataLoader, along the lines of the sketch below.
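For context, that earlier map-style approach looked roughly like this (the batch sampler and the counts are illustrative, not the exact code):

import random
from torch.utils.data import DataLoader, Sampler

class BucketBatchSampler(Sampler):
    """Yield lists of indices where every batch is drawn from one bucket,
    with buckets chosen in proportion to their size."""

    def __init__(self, bucket_indices, batch_size, num_batches):
        self.bucket_indices = bucket_indices   # one list of dataset indices per bucket
        self.weights = [len(b) for b in bucket_indices]
        self.batch_size = batch_size
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            bucket = random.choices(self.bucket_indices, weights=self.weights, k=1)[0]
            yield random.sample(bucket, self.batch_size)

# Used with a map-style dataset:
# loader = DataLoader(map_style_dataset,
#                     batch_sampler=BucketBatchSampler(bucket_indices, 32, 10_000),
#                     num_workers=4)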
Were you able to solve this issue? I am running into a somewhat similar problem: my CPU memory usage increases, but nothing drastic (well below my RAM limits). However, the swap space usage increases exorbitantly, leading to my jobs getting killed. Here are some plots showing the problem.
The green lines below show a working training run where everything is smooth (this is not a webdataset loader but a standard DataLoader with ImageFolder datasets). In the red lines, however, I have both the image dataloader and a webdataset loader loading data together, and I cycle through them repeatedly. At the point where the red lines stop, my training job gets killed. From the CPU RAM plots nothing untoward seems to be happening: there is an initial spike but then things stabilise, similar to the green line. However, when I track the free swap space (not shown in the plots), it rapidly goes down until it reaches 0, and a few seconds later the job crashes.
The data-loading code for the green line looks something like this:
# Start Training.
for ind, data in enumerate(train_loader):
    batch_size = len(data["targets"])
and for the red line it looks like:
dl = get_wds_dataset(
    './data/{0000..0090}.tar',
    preprocess_img=transform,
    preprocess_label=lambda x: x,
    is_train=True,
    epoch=0,
    batch_size=512,
    num_samples=1_000_000,
    resampled=False,
    seed=0,
    workers=2,
)
train_loader = dl.dataloader  # get_wds_dataset returns a DataInfo; use its dataloader

# Start Training.
for ind, data in enumerate(train_loader):
    batch_size = len(data["targets"])
where get_wds_dataset is the function from here: https://github.com/mlfoundations/open_clip/blob/2e8de8312ea6185df7bc24a73b19a195a801a9fc/src/training/data.py#L328
I have tried running the script with workers=0 and it still fails and gets killed after a while. Does anyone have any ideas on how to resolve this?
Update: this was because the itertools.cycle function leads to memory leaks (https://github.com/pytorch/pytorch/issues/23900). Without the cycle function everything works smoothly. This was not a webdataset issue.
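For anyone hitting the same thing: itertools.cycle saves a copy of every element it has yielded, so cycling a DataLoader that produces large batches keeps accumulating memory until the first full pass completes. A simple replacement that just restarts the loader looks like this (names are illustrative):

def cycle_loader(loader):
    """Restart the DataLoader when it is exhausted, instead of using
    itertools.cycle, which keeps a copy of every batch it has seen."""
    while True:
        for batch in loader:
            yield batch

# train_iter = cycle_loader(train_loader)
# for step in range(num_steps):
#     data = next(train_iter)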
Sorry for the long delay.
WebDataset does not retain any samples unless you ask it to explicitly (e.g. in a shuffle buffer).
In your case, just be sure not to use shuffle in each of the component datasets; instead, create a new IterableDataset and then pipeline that with the shuffle method. The code looks something like this:
class MyBucketSampler(IterableDataset):
    ...
    def __iter__(self):
        ...

bucketsampler = MyBucketSampler(...)

dataset = wds.DataPipeline(
    bucketsampler,
    wds.shuffle(...),
    wds.batched(...),
)
I think your bucketsampler should be an iterator that randomly samples from all the tarfiles of every bucket. In the demo code above, the data is merely fetched randomly, shuffled, and collated, but nothing ensures that all the data in a single batch comes from the same bucket; such an approach won't achieve the purpose of bucket training.
And I believe @ethansmith2000's 2-level shuffling approach works well; avoiding shuffle in each of the component datasets is not feasible. A rough sketch of one possible way to keep every batch inside a single bucket follows.
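As a sketch only (the shard patterns, sample keys, resampled flag, and shuffle buffer size are my assumptions), a MyBucketSampler that keeps one shuffled stream per bucket and emits whole batches from a single bucket could look like this; since it forms the batches itself, the outer wds.batched step would be dropped:

import random
import webdataset as wds
from torch.utils.data import IterableDataset

class MyBucketSampler(IterableDataset):
    """Sketch: one resampled, per-bucket shuffled stream per bucket; every
    yielded batch is drawn entirely from a single bucket."""

    def __init__(self, bucket_shards, bucket_counts, batch_size, shuffle=1000):
        self.bucket_shards = bucket_shards   # one shard pattern per bucket
        self.weights = bucket_counts
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        streams = [
            iter(
                wds.WebDataset(shards, resampled=True)
                   .shuffle(self.shuffle)
                   .decode("pil")
                   .to_tuple("jpg", "json")
            )
            for shards in self.bucket_shards
        ]
        while True:
            # Pick a bucket weighted by size, then pull a full batch from it.
            stream = random.choices(streams, weights=self.weights, k=1)[0]
            yield [next(stream) for _ in range(self.batch_size)]

# dataset = wds.DataPipeline(MyBucketSampler(bucket_shards, bucket_counts, 32))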