imbalanced-dataset-sampler

ConcatDataset support

Open • jasonbian97 opened this issue 3 years ago • 1 comment

Thanks for the great work!

I tried to combine two datasets using `dataset = dataset1 + dataset2`, and it gives me this error: `AttributeError: 'ConcatDataset' object has no attribute 'get_labels'`
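For reference, here is a minimal sketch that reproduces the condition behind this error. It assumes a toy dataset (the hypothetical `LabeledDataset`) that implements `get_labels()`; `Dataset.__add__` wraps both operands in a `torch.utils.data.ConcatDataset`, and that wrapper does not forward custom methods, so the sampler's generic `dataset.get_labels()` call has nothing to call.

```python
from torch.utils.data import ConcatDataset, Dataset


class LabeledDataset(Dataset):
    """Toy dataset (hypothetical) that exposes get_labels() like the sampler expects."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # (sample, label) pair; the "sample" is just the index here
        return idx, self.labels[idx]

    def get_labels(self):
        return self.labels


dataset1 = LabeledDataset([0, 0, 1])
dataset2 = LabeledDataset([1, 2])

# Dataset.__add__ returns ConcatDataset([dataset1, dataset2]) ...
dataset = dataset1 + dataset2
print(isinstance(dataset, ConcatDataset))  # True
# ... which has no get_labels(), hence the AttributeError in the sampler.
print(hasattr(dataset, "get_labels"))      # False
print(len(dataset))                        # 5
```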

Is there any workaround?

jasonbian97 • Aug 05 '21 22:08

Nvm, I found a workaround myself, and it's pretty simple:

Add two lines to the sampler's `_get_labels` method, plus one helper method:

    # Requires `import torch` and `import torchvision` at module level.
    def _get_labels(self, dataset):
        if self.callback_get_label:
            return self.callback_get_label(dataset)
        elif isinstance(dataset, torchvision.datasets.MNIST):
            return dataset.targets.tolist()
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return [x[1] for x in dataset.imgs]
        elif isinstance(dataset, torchvision.datasets.DatasetFolder):
            # samples is a list of (path, class_index) tuples
            return [x[1] for x in dataset.samples]
        elif isinstance(dataset, torch.utils.data.Subset):
            # keep only the parent's labels at the subset's indices, in order
            parent_labels = self._get_labels(dataset.dataset)
            return [parent_labels[i] for i in dataset.indices]
        elif isinstance(dataset, torch.utils.data.ConcatDataset):  # added; must precede the generic Dataset branch because ConcatDataset subclasses torch.utils.data.Dataset
            return self._get_concat_labels(dataset)  # added
        elif isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        else:
            raise NotImplementedError

    def _get_concat_labels(self, concat_dataset):  # added
        # ConcatDataset indexes its parts in order, so appending each part's
        # labels keeps label i aligned with sample i. Going through
        # _get_labels also handles nested ConcatDatasets and torchvision parts.
        concat_labels = []
        for ds in concat_dataset.datasets:
            concat_labels.extend(self._get_labels(ds))
        return concat_labels
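An alternative that avoids editing the library: the `_get_labels` code above checks `self.callback_get_label` first, so a label callback that understands `ConcatDataset` could be supplied instead. This is a minimal sketch, assuming toy datasets that implement `get_labels()`; `LabeledDataset` and `concat_aware_labels` are hypothetical names, not part of the library.

```python
from torch.utils.data import ConcatDataset, Dataset


class LabeledDataset(Dataset):
    """Toy dataset (hypothetical) exposing get_labels()."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return idx, self.labels[idx]

    def get_labels(self):
        return self.labels


def concat_aware_labels(dataset):
    """Hypothetical callback: return labels in the same order as dataset indices.

    Recurses into ConcatDataset (including nested ones, e.g. from a + b + c);
    any other dataset is assumed to implement get_labels().
    """
    if isinstance(dataset, ConcatDataset):
        labels = []
        for ds in dataset.datasets:
            labels.extend(concat_aware_labels(ds))
        return labels
    return dataset.get_labels()


# (a + b) + c builds a ConcatDataset nested inside another ConcatDataset
dataset = LabeledDataset([0, 0, 1]) + LabeledDataset([1, 2]) + LabeledDataset([2])
labels = concat_aware_labels(dataset)
print(labels)  # [0, 0, 1, 1, 2, 2]

# Sanity check: label i must describe sample i of the concatenated dataset.
assert all(dataset[i][1] == labels[i] for i in range(len(dataset)))
```

Assuming the sampler's constructor accepts a `callback_get_label` keyword (the `self.callback_get_label` attribute suggests it does), this could be passed as `ImbalancedDatasetSampler(dataset, callback_get_label=concat_aware_labels)` with the library left unmodified.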

Let me know if you have a better solution!
