Supporting samplers for creating balanced batches
Hey everyone,
We're working with a dataset where items belong to different clusters.
Example:
item_1, cluster_A
item_2, cluster_A
item_3, cluster_B
I was wondering if there's a way to create batches that have a guaranteed mix of items from these clusters? For example, to ensure every single batch is made up of 50% items from cluster A and 50% from cluster B.
In PyTorch DataLoader for in-memory datasets, we handle this with a custom Sampler that you pass to the DataLoader (example code). The basic idea is to get lists of indices for each cluster, shuffle them, and then build the batches by picking one index from each cluster's list in a round-robin fashion. This approach can upsample the minority clusters to create balanced batches even when the original data is not.
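For reference, the sampler idea described above can be sketched roughly like this. It is a minimal illustration, not the linked example code: `BalancedBatchSampler` is a hypothetical name, and it assumes the cluster label of every item is available as an in-memory list. The class is a plain iterable of index lists, which is the shape PyTorch's `DataLoader` accepts for its `batch_sampler` argument.

```python
import random
from collections import defaultdict


class BalancedBatchSampler:
    """Yields batches of dataset indices with an equal share from every cluster.

    Hypothetical helper, not part of PyTorch.
    """

    def __init__(self, labels, batch_size, seed=0):
        self.by_cluster = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_cluster[label].append(idx)
        if batch_size % len(self.by_cluster) != 0:
            raise ValueError("batch_size must be divisible by the number of clusters")
        self.per_cluster = batch_size // len(self.by_cluster)
        self.rng = random.Random(seed)
        # One epoch covers the largest cluster once; smaller clusters repeat.
        largest = max(len(v) for v in self.by_cluster.values())
        self.num_batches = -(-largest // self.per_cluster)  # ceiling division

    def __len__(self):
        return self.num_batches

    def _stream(self, indices):
        """Endless stream over one cluster, reshuffled on every pass (upsampling)."""
        while True:
            order = indices[:]
            self.rng.shuffle(order)
            yield from order

    def __iter__(self):
        streams = [self._stream(v) for v in self.by_cluster.values()]
        for _ in range(self.num_batches):
            # Take the same number of indices from each cluster, then mix them.
            batch = [next(s) for s in streams for _ in range(self.per_cluster)]
            self.rng.shuffle(batch)
            yield batch


# Toy data: six items in cluster A, two in cluster B.
labels = ["A", "A", "A", "A", "A", "A", "B", "B"]
sampler = BalancedBatchSampler(labels, batch_size=4)
batches = list(sampler)
```

With an in-memory dataset this would be passed as `DataLoader(dataset, batch_sampler=sampler)`. The minority cluster is repeated within an epoch, which is where the upsampling comes from.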
Is there any way to do something like this? Any guidance would be much appreciated, and I'd be happy to help contribute if pointed in the right direction.
Thanks
Hi @karinazad,
I think CombinedStreamingDataset might be relevant here, though it doesn't enforce per-batch weights.
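To make the "no per-batch guarantee" point concrete, here is a small standard-library simulation. It is not the library's code: it only assumes that the source of each item is drawn independently according to the weights, and `weighted_batch` is a hypothetical helper. Under that assumption the mix is 50/50 on average, but individual batches drift away from it.

```python
import random


def weighted_batch(weights, batch_size, rng):
    """Pick the source cluster of each item independently, according to weights."""
    return rng.choices(["A", "B"], weights=weights, k=batch_size)


rng = random.Random(0)
batch_size = 8

# Count how many cluster-A items land in each of 1000 simulated batches.
counts = [weighted_batch([0.5, 0.5], batch_size, rng).count("A") for _ in range(1000)]

# The average is close to batch_size / 2, but min and max show the per-batch spread.
print(min(counts), max(counts), sum(counts) / len(counts))
```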
If the goal is strict balance (e.g. 50% cluster A / 50% cluster B in every batch), one option is to stream each cluster separately and then zip their loaders:
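from litdata import StreamingDataset, StreamingDataLoader

batch_size = 256
ds_a = StreamingDataset("optimized_dataset_cluster_A")
ds_b = StreamingDataset("optimized_dataset_cluster_B")
# Each loader yields half a batch, so every combined batch is 50% A / 50% B.
dl_a = StreamingDataLoader(ds_a, batch_size=batch_size // 2)
dl_b = StreamingDataLoader(ds_b, batch_size=batch_size // 2)

# Note: zip stops as soon as the smaller cluster is exhausted.
for batch_a, batch_b in zip(dl_a, dl_b):
    # Pseudocode: how to merge and shuffle depends on the batch type.
    batch = shuffle(batch_a + batch_b)
    # train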
Haven’t been super active in the project lately, so treat this more as a direction than a guarantee.
To get the upsampling you described, you could replace `zip` with a simple `zip_upsample`:
def zip_upsample(*loaders):
    """Like zip, but restarts exhausted loaders until every loader has finished one pass."""
    iters = [iter(dl) for dl in loaders]
    finished = [False] * len(loaders)
    while True:
        batch = []
        for i, it in enumerate(iters):
            try:
                batch.append(next(it))
            except StopIteration:
                finished[i] = True
                if all(finished):
                    return  # every loader has completed at least one pass
                iters[i] = iter(loaders[i])  # restart the shorter loader (upsampling)
                batch.append(next(iters[i]))
        yield batch
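
A quick way to sanity-check the behaviour is to run it on toy data. The sketch below uses plain Python lists as stand-ins for the two loaders and restates `zip_upsample` so it runs on its own, under the assumption that iteration stops once every loader has completed a pass. `merge_and_shuffle` is a hypothetical helper for list batches; with tensor batches something like `torch.cat` plus `torch.randperm` would play the same role.

```python
import random


def zip_upsample(*loaders):
    """Restart shorter loaders until every loader has finished one pass."""
    iters = [iter(dl) for dl in loaders]
    finished = [False] * len(loaders)
    while True:
        batch = []
        for i, it in enumerate(iters):
            try:
                batch.append(next(it))
            except StopIteration:
                finished[i] = True
                if all(finished):
                    return
                iters[i] = iter(loaders[i])  # restart the shorter loader
                batch.append(next(iters[i]))
        yield batch


def merge_and_shuffle(parts, rng):
    """Concatenate per-cluster half-batches (lists here) and shuffle the result."""
    merged = [item for part in parts for item in part]
    rng.shuffle(merged)
    return merged


# Toy stand-ins: cluster A has three half-batches, cluster B only one.
loader_a = [["a1", "a2"], ["a3", "a4"], ["a5", "a6"]]
loader_b = [["b1", "b2"]]

rng = random.Random(0)
batches = [merge_and_shuffle(parts, rng) for parts in zip_upsample(loader_a, loader_b)]
```

Every combined batch here holds two items from each cluster, and the single half-batch of cluster B is reused until cluster A has been seen once.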