webdataset
How to implement batch sampler on webdataset?
Hi everyone, I hope you are doing well. I have a technical question about webdataset: I am trying to implement a custom batch sampler. With the DataLoader from torch this is straightforward, because we can define our own batch sampler class, for example:
import numpy as np


class ExpSampler:
    def __init__(self, dataset, random: bool = True):
        self.dataset = dataset
        # Unique experiment IDs; each one becomes exactly one batch.
        self.exps = self.dataset.df["experiment"].unique()
        self.random = random

    def __iter__(self):
        indexes = np.arange(len(self.exps))
        if self.random:
            np.random.shuffle(indexes)
        for exp_id in indexes:
            exp = self.exps[exp_id]
            mask = self.dataset.df["experiment"] == exp
            # Indices of all wells belonging to this experiment.
            all_wells = np.array(self.dataset.df[mask].index)
            yield all_wells

    def __len__(self):
        return len(self.exps)
We then pass this sampler to the DataLoader through the batch_sampler argument:
from torch.utils.data import DataLoader

batch_sampler = ExpSampler(train_data)
train_dl = DataLoader(
    train_data,
    num_workers=12,
    pin_memory=True,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)
With webdataset, we first create the dataset:
dataset = webdataset.WebDataset(
    file_names,
    resampled=False,
    nodesplitter=webdataset.split_by_node,
    shardshuffle=False,
    empty_check=False,
    handler=log_and_continue,
)
and then the loader:
loader = webdataset.WebLoader(
    dataset.batched(16, collation_fn=ban_full_lib_collate_fn),
    num_workers=num_workers,
    persistent_workers=False,
    pin_memory=True,
)
I searched the webdataset documentation but could not find a way to plug in a custom batch sampler. This is something that should be done in any bioML project to prevent experimental batch effects: each batch should contain data from one experimental condition only. Is there a recommended way to achieve this with webdataset?
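To make the desired behaviour concrete, here is a minimal sketch of it as a plain Python generator over a stream of samples. It rests on assumptions of mine and is not webdataset's API: each sample is a dict carrying a hypothetical "experiment" field, and the shards were written so that samples from the same experiment sit next to each other. The names group_by_experiment and toy_stream are made up for illustration.

```python
from itertools import groupby
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional

Sample = Dict[str, Any]


def group_by_experiment(
    samples: Iterable[Sample],
    key: str = "experiment",  # hypothetical field name, not a webdataset convention
    collate: Optional[Callable[[List[Sample]], Any]] = None,
) -> Iterator[Any]:
    """Yield one batch per run of consecutive samples that share sample[key].

    Only adjacent samples are grouped, so the stream has to be ordered by
    experiment beforehand (assumption: the shards were written that way).
    """
    for _, run in groupby(samples, key=lambda sample: sample[key]):
        batch = list(run)
        yield collate(batch) if collate is not None else batch


# Toy stream standing in for decoded samples read from shards.
toy_stream = [
    {"experiment": "exp_a", "well": 0},
    {"experiment": "exp_a", "well": 1},
    {"experiment": "exp_b", "well": 2},
    {"experiment": "exp_b", "well": 3},
    {"experiment": "exp_b", "well": 4},
    {"experiment": "exp_c", "well": 5},
]

batches = list(group_by_experiment(toy_stream))
for batch in batches:
    print(batch[0]["experiment"], [sample["well"] for sample in batch])
```

Batch sizes vary with the number of wells per experiment, just as with ExpSampler above. I assume a generator like this could be attached as a pipeline stage in place of .batched(16, ...), for example through dataset.compose(...), but I have not verified that. It would also break as soon as one experiment spans several shards that end up on different workers.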