
Supporting samplers for creating balanced batches

Open · karinazad opened this issue 3 months ago · 1 comment

Hey everyone,

We're working with a dataset where items belong to different clusters.

Example:

item_1, cluster_A
item_2, cluster_A
item_3, cluster_B

I was wondering if there's a way to create batches that have a guaranteed mix of items from these clusters? For example, to ensure every single batch is made up of 50% items from cluster A and 50% from cluster B.

With a plain PyTorch DataLoader over in-memory datasets, we handle this with a custom Sampler passed to the DataLoader (example code). The basic idea is to collect the lists of indices for each cluster, shuffle them, and build batches by picking one index from each cluster's list in round-robin fashion. This upsamples the minority clusters, producing balanced batches even when the original data is not balanced.
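For reference, here's a minimal sketch of that round-robin sampler in plain Python (the `cluster_indices` mapping and the `balanced_batches` name are illustrative, not from any library):

```python
import random

def balanced_batches(cluster_indices, batch_size, seed=0):
    """Yield index batches with an equal share from each cluster.

    cluster_indices: dict mapping cluster name -> list of dataset indices.
    Minority clusters are upsampled by reshuffling and restarting their
    index list, so every batch keeps the same per-cluster proportions.
    """
    rng = random.Random(seed)
    per_cluster = batch_size // len(cluster_indices)
    # Shuffled private copy of each cluster's index list.
    pools = {c: rng.sample(idx, len(idx)) for c, idx in cluster_indices.items()}
    cursors = {c: 0 for c in cluster_indices}
    # One "epoch" = one pass over the largest cluster.
    n_batches = max(len(v) for v in cluster_indices.values()) // per_cluster
    for _ in range(n_batches):
        batch = []
        for c, pool in pools.items():
            for _ in range(per_cluster):
                if cursors[c] >= len(pool):   # minority cluster exhausted:
                    rng.shuffle(pool)         # reshuffle and restart (upsample)
                    cursors[c] = 0
                batch.append(pool[cursors[c]])
                cursors[c] += 1
        rng.shuffle(batch)  # avoid a fixed cluster order within the batch
        yield batch
```

A generator like this can be passed to a PyTorch DataLoader as `batch_sampler` for the in-memory case; the question below is about getting the same behavior with streaming datasets.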

Is there any way to do something like this? Any guidance would be much appreciated, and I'd be happy to contribute if pointed in the right direction.

Thanks

karinazad avatar Sep 22 '25 18:09 karinazad

Hi @karinazad,

I think CombinedStreamingDataset might be relevant here, though it doesn't enforce per-batch weights.

If the goal is strict balance (e.g. 50% cluster A / 50% cluster B in every batch), one option is to stream each cluster separately and then zip their loaders:

from litdata import StreamingDataset, StreamingDataLoader

batch_size = 256

# One optimized dataset per cluster.
ds_a = StreamingDataset("optimized_dataset_cluster_A")
ds_b = StreamingDataset("optimized_dataset_cluster_B")

# Half the batch from each cluster.
dl_a = StreamingDataLoader(ds_a, batch_size=batch_size // 2)
dl_b = StreamingDataLoader(ds_b, batch_size=batch_size // 2)

for batch_a, batch_b in zip(dl_a, dl_b):
    # `shuffle` here is pseudocode: concatenate and permute the two half-batches
    batch = shuffle(batch_a + batch_b)
    # train
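That `shuffle(batch_a + batch_b)` is pseudocode; with list-style batches, a concrete merge could be (for tensor batches you'd use `torch.cat` plus a random permutation instead):

```python
import random

def merge_and_shuffle(batch_a, batch_b, seed=None):
    """Concatenate two half-batches (lists of samples) and shuffle them,
    so position in the batch doesn't leak cluster membership."""
    merged = list(batch_a) + list(batch_b)
    random.Random(seed).shuffle(merged)
    return merged
```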

Haven’t been super active in the project lately, so treat this more as a direction than a guarantee.


Replacing `zip` with a simple `zip_upsample` gives the upsampling behavior you described: exhausted loaders restart, and iteration stops once every loader has been exhausted at least once.

def zip_upsample(*loaders):
    iters = [iter(dl) for dl in loaders]
    finished = [False] * len(loaders)
    while not all(finished):
        batch = []
        for i, it in enumerate(iters):
            try:
                batch.append(next(it))
            except StopIteration:
                finished[i] = True
                iters[i] = iter(loaders[i])  # restart the exhausted loader
                batch.append(next(iters[i]))
        yield batch

# e.g. with plain lists standing in for loaders:
# list(zip_upsample([1, 2, 3], [9]))  ->  [[1, 9], [2, 9], [3, 9], [1, 9]]

deependujha avatar Sep 23 '25 05:09 deependujha