
Using a streaming dataloader with an unbalanced dataset yields unexpected batch sizes.

Open esivonxay-cognitiv opened this issue 1 year ago • 10 comments

🐛 Bug

I have two unbalanced datasets, where one is 1000x larger than the other. I would like to sample from the two datasets such that the ratio of samples from each is 1:100. When doing so, batches of irregular size are returned during iteration.

I think there are 2 issues which this test surfaces:

  1. The first batch returned by each worker is not properly sized.
  2. drop_last does not appear to work as intended, since the last batch is not a full-sized batch.

I don't think this is related to #179, but it's possible.

I've been attempting to fix this, but I'm not sure what the root of the issue is. I would be very appreciative if you could fix this or point me in the right direction.

Thanks!

To Reproduce

import os
import sys

import pytest

# Import paths below follow litdata's test suite layout and are assumed; adjust if the package layout differs.
from litdata.streaming import Cache, CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from litdata.streaming.resolver import Dir


@pytest.mark.skipif(sys.platform == "win32", reason="too slow in CI")
def test_unbalanced_combined_dataset_with_dataloader(tmpdir):
    data_dir_1 = os.path.join(tmpdir, "data_1")
    data_dir_2 = os.path.join(tmpdir, "data_2")
    cache_dir_1 = os.path.join(tmpdir, "cache_dir_1")
    cache_dir_2 = os.path.join(tmpdir, "cache_dir_2")

    os.makedirs(data_dir_1)
    os.makedirs(data_dir_2)
    os.makedirs(cache_dir_1)
    os.makedirs(cache_dir_2)

    cache = Cache(input_dir=str(data_dir_1), chunk_size=2)

    for i in range(10):
        cache[i] = i

    cache.done()
    cache.merge()

    cache = Cache(input_dir=str(data_dir_2), chunk_size=2)

    for i in range(10000):
        cache[i] = i + 10

    cache.done()
    cache.merge()

    dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
    dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
    dataset = CombinedStreamingDataset(
        datasets=[dataset1, dataset2], weights=[0.01, 0.99], iterate_over_all=False, seed=12345
    )
    dataloader = StreamingDataLoader(
        dataset,
        num_workers=3,
        batch_size=100,
        drop_last=True,
        persistent_workers=True,
        shuffle=True,
        prefetch_factor=2,
    )

    assert dataset1.current_epoch == 1
    assert dataset2.current_epoch == 1

    batches_1 = []
    batch_sizes_1 = []
    for batch in dataloader:
        batch_sizes_1.append(batch.size(0))
        batches_1.append(batch)

    assert batch_sizes_1[2] == 91
    assert batch_sizes_1[-1] == 40
    # This will fail since the third and last batch sizes are not 100 (the two assertions above pass).
    assert batch_sizes_1 == [100 for _ in batches_1]

Expected behavior

All batch sizes should be the same.

Additional context

This issue occurs regardless of whether drop_last, shuffle, and persistent_workers are set to True or False.

esivonxay-cognitiv avatar Jun 29 '24 18:06 esivonxay-cognitiv

Hey @esivonxay-cognitiv, Thanks for the reproducible script. I will have a look into it.

tchaton avatar Jun 30 '24 07:06 tchaton

Thanks Thomas!

esivonxay-cognitiv avatar Jun 30 '24 15:06 esivonxay-cognitiv

Hey @esivonxay-cognitiv, I am curious: what's your interest in and usage of LitData?

tchaton avatar Jun 30 '24 17:06 tchaton

Yeah, I'm interested in LitData primarily for the ability to sample from multiple streams. I've got 2 datasets which are quite imbalanced (one is 100,000x larger than the other) and I'm trying to downsample one dataset to reduce the imbalance by a couple orders of magnitude.

Naively, I could do this when constructing the dataset by throwing out datapoints. However, doing so would mean throwing out 90% or 99% of the data (to decrease the imbalance by 10x or 100x, respectively), and important samples might be lost in the process.

My thought was to do this downsampling/rebalancing during dataloading so the model at least has a chance to see each sample, just at a lower rate.
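
As an illustration, here is a minimal sketch of that rebalancing approach, using the same public litdata classes as the repro above; the paths are placeholders and the weight values are only an example:

from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Placeholder paths; in practice these point at the optimized/streamable datasets.
small = StreamingDataset(input_dir="data/small", shuffle=True)
large = StreamingDataset(input_dir="data/large", shuffle=True)

# With weights=[0.01, 0.99] each sample is drawn from the small dataset ~1% of the
# time, so the effective ratio seen by the model is roughly 1:99 even though the
# raw size ratio is far more skewed. No data is discarded up front.
dataset = CombinedStreamingDataset(
    datasets=[small, large], weights=[0.01, 0.99], iterate_over_all=False
)
dataloader = StreamingDataLoader(dataset, batch_size=100, num_workers=3)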

esivonxay-cognitiv avatar Jul 02 '24 00:07 esivonxay-cognitiv

I recently encountered a similar issue while training a model with a batch normalization layer. Since batch normalization requires a batch size greater than 1 during training, the training process fails if a batch size of 1 is produced.

There may be a solution discussed here: using drop_last in the DataLoader causes PyTorch to automatically skip incomplete batches.
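
For reference, this is how drop_last behaves in a plain torch.utils.data.DataLoader (a self-contained illustration, unrelated to litdata's internals):

import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 samples with batch_size=4: without drop_last the last batch has only 2 samples,
# which is what breaks batch normalization during training.
dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=4, drop_last=True)
sizes = [batch[0].size(0) for batch in loader]
assert sizes == [4, 4]  # the trailing partial batch of 2 is skipped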

However, drop_last is not included in the StreamingDataLoader's signature, and I'm not sure whether this omission is intentional. https://github.com/Lightning-AI/litdata/blob/c4c9117134e20c6b3f9ca7b071932475ab12da80/src/litdata/streaming/dataloader.py#L597-L605

jackcyc avatar Jul 10 '24 21:07 jackcyc

Hey @jackcyc @esivonxay-cognitiv,

Would either of you be willing to attempt a fix? The CombinedDataset isn't well thought out, IMO, and needs to be improved. It was designed for very large-scale training where only a few epochs are run. Your use case is kind of an edge case.

I think we should re-write it, using PyTorch Lightning's CombinedLoader for inspiration: https://github.com/Lightning-AI/pytorch-lightning/blob/50af052b3129164e28efa8b9321d733311b7b459/src/lightning/pytorch/utilities/combined_loader.py#L222
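
For the record, here is a rough sketch of one direction such a rewrite could take: draw samples from the underlying iterators by weight and collate into batches at the combined level, so only the final (droppable) batch can ever be short. This is purely illustrative and not litdata's or Lightning's implementation; all names are hypothetical.

import random
import torch

def weighted_interleave(iterators, weights, seed=12345):
    # Yield one sample at a time, choosing the source iterator by weight.
    # Stops as soon as any iterator is exhausted (like iterate_over_all=False).
    rng = random.Random(seed)
    while True:
        (it,) = rng.choices(iterators, weights=weights, k=1)
        try:
            yield next(it)
        except StopIteration:
            return

def batched(stream, batch_size, drop_last=True):
    # Collate the combined stream into fixed-size batches; with drop_last=True
    # a trailing partial batch is discarded, so every yielded batch is full.
    buffer = []
    for sample in stream:
        buffer.append(sample)
        if len(buffer) == batch_size:
            yield torch.tensor(buffer)
            buffer = []
    if buffer and not drop_last:
        yield torch.tensor(buffer)

small, large = iter(range(10)), iter(range(10, 10_010))
for batch in batched(weighted_interleave([small, large], [0.01, 0.99]), batch_size=100):
    assert batch.size(0) == 100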

tchaton avatar Jul 11 '24 07:07 tchaton

Hey Thomas, thanks for the followup.

I haven't looked at the PyTorch Lightning implementation exhaustively, but thanks for bringing it to my attention. I don't currently have the bandwidth for this, but I'll put it on my list of todos and revisit fixing/re-writing this.

esivonxay-cognitiv avatar Jul 16 '24 20:07 esivonxay-cognitiv

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] avatar Apr 16 '25 05:04 stale[bot]

Let's keep this. Worth checking at some point.

bhimrazy avatar Apr 17 '25 05:04 bhimrazy

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] avatar Jul 19 '25 06:07 stale[bot]