
random_split looks like it cannot generate label-balanced sub-datasets

zhutmost opened this issue 4 years ago

🐛 Bug

The PyTorch built-in function torch.utils.data.random_split is used in multiple DataModules. However, this function splits purely at random, so it cannot generate label-balanced sub-datasets.

To Reproduce

Steps to reproduce the behavior:

Run the following code:

from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule('/localhome/fair/Dataset/cifar10')

# count how many samples of each of the 10 classes land in the training split
stat = [0] * 10
for batch in dm.train_dataloader():
    inputs, targets = batch
    for b in range(targets.size(0)):
        stat[targets[b].item()] += 1
stat

and it will output per-class counts that fluctuate around 4,500 rather than matching it exactly:

[4512, 4486, 4466, 4529, 4528, 4485, 4493, 4499, 4495, 4499]

Expected behavior

We want a label-balanced output. That is, the label distribution of each split sub-dataset should keep the same class proportions as the original dataset. For CIFAR-10's default 45,000-sample training split, that means exactly 4,500 samples per class:

[4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500]

Environment

  • PyTorch Version: 1.6
  • OS: Ubuntu 18.04 LTS
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: 2080Ti

zhutmost · Oct 27 '20

This behavior of random_split has been discussed many times in the community (see the linked discussion). It can produce an inconspicuously skewed val_dataset, especially when the dataset is not large enough.

A better approach is sklearn.model_selection.train_test_split, which supports stratified splitting. Here is example code:

import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import torchvision as tv

def __balance_val_split(dataset, val_split=0.1):
    # stratify on the dataset labels so each class appears in the
    # train/val subsets with the same proportion as in the full dataset
    targets = np.array(dataset.targets)
    train_indices, val_indices = train_test_split(
        np.arange(targets.shape[0]),
        test_size=val_split,
        stratify=targets,
    )
    train_dataset = Subset(dataset, indices=train_indices)
    val_dataset = Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset
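
For illustration, the helper above could be dropped into a CIFAR-10 pipeline like this (the path, batch size, and 10% split here are placeholder values):

# hypothetical usage; any torchvision dataset exposing a `.targets` attribute works
dataset = tv.datasets.CIFAR10('/path/to/cifar10', train=True, download=True,
                              transform=tv.transforms.ToTensor())
train_dataset, val_dataset = __balance_val_split(dataset, val_split=0.1)
# every class now contributes exactly 10% of its samples to the validation set
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)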

zhutmost · Oct 27 '20

Hi @zhutmost, this looks great, but it also means we'll have a hard dependency on sklearn... I'm looking into how we can implement a stratified split and will keep you updated!

annikabrundyn · Nov 18 '20

The reason I moved to the PyTorch split was to drop the dependency on sklearn, as this seemed to be the only usage... :] So shall we use it again, or just implement the split ourselves?

Borda · Nov 18 '20
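
Implementing the split ourselves without sklearn would be fairly compact. Below is a minimal sketch, assuming the dataset exposes a .targets attribute; the name stratified_split and its defaults are illustrative, not an existing bolts API:

import torch
from torch.utils.data import Subset

def stratified_split(dataset, val_split=0.1, seed=42):
    # shuffle the indices of each class separately, then carve off the same
    # fraction per class so the split preserves the label proportions
    targets = torch.as_tensor(dataset.targets)
    generator = torch.Generator().manual_seed(seed)
    train_indices, val_indices = [], []
    for cls in targets.unique():
        cls_indices = (targets == cls).nonzero(as_tuple=True)[0]
        perm = cls_indices[torch.randperm(len(cls_indices), generator=generator)]
        n_val = int(len(perm) * val_split)
        val_indices.append(perm[:n_val])
        train_indices.append(perm[n_val:])
    return (Subset(dataset, torch.cat(train_indices).tolist()),
            Subset(dataset, torch.cat(val_indices).tolist()))

On CIFAR-10 with val_split=0.1, each class contributes exactly 500 samples to validation and 4,500 to training, matching the expected output above.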

I've started implementing a generalized split function for PL, so I'll create a PR sometime this weekend for you to review, if that makes sense? :)

annikabrundyn · Nov 20 '20

> The reason I moved to the PyTorch split was to drop the dependency on sklearn, as this seemed to be the only usage... :] So shall we use it again, or just implement the split ourselves?

Thanks a lot for your great pl_bolts. I find that some DataModules, such as Imagenet_dataset, also depend on sklearn, so it looks like it is not a serious problem... maybe?

zhutmost · Nov 26 '20