imbalanced-dataset-sampler
ConcatDataset support
Thanks for the great work!
I tried to combine two datasets with `dataset = dataset1 + dataset2`, and it raises this error: AttributeError: 'ConcatDataset' object has no attribute 'get_labels'
Is there any workaround?
Nvm, I found a workaround myself, and it's pretty simple:
add two lines to `_get_labels` plus one helper function:
# Methods of the sampler class; the module needs `import torch` and `import torchvision`.
def _get_labels(self, dataset):
    if self.callback_get_label:
        return self.callback_get_label(dataset)
    elif isinstance(dataset, torchvision.datasets.MNIST):
        return dataset.train_labels.tolist()
    elif isinstance(dataset, torchvision.datasets.ImageFolder):
        return [x[1] for x in dataset.imgs]
    elif isinstance(dataset, torchvision.datasets.DatasetFolder):
        # samples is a list of (path, class_index) pairs; keep the class index of each
        return [x[1] for x in dataset.samples]
    elif isinstance(dataset, torch.utils.data.Subset):
        # only the labels of the indices selected by the Subset
        return [dataset.dataset.imgs[i][1] for i in dataset.indices]
    elif isinstance(dataset, torch.utils.data.ConcatDataset):  # added; must come before the next `elif` because ConcatDataset is itself a torch.utils.data.Dataset
        return self._get_concat_labels(dataset)  # added
    elif isinstance(dataset, torch.utils.data.Dataset):
        return dataset.get_labels()
    else:
        raise NotImplementedError

def _get_concat_labels(self, concatdataset):  # added
    concat_labels = []
    for ds in concatdataset.datasets:
        # go through _get_labels so parts such as ImageFolder or Subset work too,
        # not only datasets that define get_labels()
        concat_labels.extend(self._get_labels(ds))
    return concat_labels
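If you'd rather not edit the sampler's source, the same idea could probably be passed in through `callback_get_label`, which `_get_labels` checks first. Here is a minimal sketch, assuming the sampler's constructor accepts a `callback_get_label` argument and calls it with the dataset only (as the code above suggests). `collect_labels` is a hypothetical name, and it uses duck typing instead of `isinstance` checks so the sketch itself runs without torch installed:

```python
def collect_labels(dataset):
    """Return a flat list of labels, descending into concatenated datasets."""
    # ConcatDataset keeps its parts in `.datasets`; recursing also covers
    # nested concatenations such as (a + b) + c.
    if hasattr(dataset, "datasets"):
        labels = []
        for part in dataset.datasets:
            labels.extend(collect_labels(part))
        return labels
    # Subset-like wrappers expose `.dataset` and `.indices`; keep only the
    # labels of the selected indices, in the Subset's order.
    if hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        parent_labels = collect_labels(dataset.dataset)
        return [parent_labels[i] for i in dataset.indices]
    # Custom datasets that implement get_labels() themselves.
    if hasattr(dataset, "get_labels"):
        return list(dataset.get_labels())
    # ImageFolder / DatasetFolder store (path, class_index) pairs in `.samples`.
    if hasattr(dataset, "samples"):
        return [label for _, label in dataset.samples]
    raise NotImplementedError(type(dataset).__name__)


# Assumed usage (constructor argument name not verified here):
# sampler = ImbalancedDatasetSampler(dataset1 + dataset2,
#                                    callback_get_label=collect_labels)
```

The trade-off is that the callback has to know about every dataset type you concatenate, but nothing in the installed package needs to change.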
Let me know if you have a better solution!