lightning-bolts
It looks like random_split cannot generate label-balanced sub-datasets
🐛 Bug
The PyTorch built-in function torch.utils.data.random_split is used in multiple DataModules. However, this function does not stratify by label, so it cannot generate label-balanced sub-datasets.
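For context, random_split simply permutes all indices and cuts them at the requested lengths, without looking at labels. A minimal sketch illustrating the behavior on a toy dataset (not bolts code):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset: 100 samples, 10 classes, 10 samples per class.
targets = torch.arange(10).repeat(10)
dataset = TensorDataset(torch.randn(100, 3), targets)

train_set, val_set = random_split(dataset, [90, 10])

# The 10 validation indices are drawn uniformly at random, so the
# per-class counts are almost never exactly 1 each.
print(torch.bincount(targets[val_set.indices], minlength=10))
```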
To Reproduce
Steps to reproduce the behavior:
Run the following code:

```python
from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule('/localhome/fair/Dataset/cifar10')

# Count how many samples of each CIFAR-10 class end up in the training split.
stat = [0 for _ in range(10)]
for batch in dm.train_dataloader():
    inputs, targets = batch
    for b in range(targets.size(0)):
        stat[targets[b].item()] += 1
stat
```
and it will output:
```
[4512, 4486, 4466, 4529, 4528, 4485, 4493, 4499, 4495, 4499]
```
Expected behavior
We want a label-balanced output. That is to say, the label distribution of each split sub-dataset should have the same proportions as the original dataset. CIFAR-10 has 10 balanced classes, so the 45,000-sample training split should contain exactly 4,500 samples per class:

```
[4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500]
```
Environment
- PyTorch Version: 1.6
- OS: Ubuntu 18.04 LTS
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- Python version: 3.8
- CUDA/cuDNN version: 10.2
- GPU models and configuration: 2080Ti
Additional context
This behavior of random_split has been discussed many times in the community, e.g. Link. It may produce an inconspicuously skewed val_dataset, especially when the dataset is not large enough.
A better implementation is sklearn.model_selection.train_test_split with its stratify option; here is example code:
```python
import numpy as np
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

def __balance_val_split(dataset, val_split=0.2):
    # Stratify on the dataset labels so that both splits keep the
    # original class proportions. val_split must be a positive fraction
    # (or an absolute sample count).
    targets = np.array(dataset.targets)
    train_indices, val_indices = train_test_split(
        np.arange(targets.shape[0]),
        test_size=val_split,
        stratify=targets,
    )
    train_dataset = Subset(dataset, indices=train_indices)
    val_dataset = Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset
```
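For instance, applied to the CIFAR-10 training set (a usage sketch; the path and split fraction are illustrative):

```python
import torchvision as tv

dataset = tv.datasets.CIFAR10('/localhome/fair/Dataset/cifar10', train=True, download=True)
train_dataset, val_dataset = __balance_val_split(dataset, val_split=0.1)

# stratify guarantees each of the 10 classes contributes exactly 10% of
# its 5,000 samples to the validation split, i.e. 500 per class.
print(len(train_dataset), len(val_dataset))  # 45000 5000
```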
Hi @zhutmost, this looks great but it also means we'll have a hard dependency on sklearn... I'm looking into how we can implement a stratified split and will keep you updated!
The reason I moved to the PyTorch split was to drop the dependency on sklearn, as this seems to be the only usage... :]
So shall we use it again, or just implement the split ourselves?
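If we implement it ourselves, a pure-PyTorch stratified split could look roughly like this (a minimal sketch; the function name, default fraction, and seed handling are assumptions, not bolts code):

```python
import torch
from torch.utils.data import Subset

def stratified_split(dataset, val_split=0.2, seed=42):
    # Shuffle the indices of each class separately and move val_split of
    # each into the validation set, so both splits keep the original
    # class proportions.
    targets = torch.as_tensor(dataset.targets)
    generator = torch.Generator().manual_seed(seed)
    train_indices, val_indices = [], []
    for cls in targets.unique():
        cls_indices = (targets == cls).nonzero(as_tuple=True)[0]
        perm = cls_indices[torch.randperm(len(cls_indices), generator=generator)]
        n_val = int(len(perm) * val_split)
        val_indices.extend(perm[:n_val].tolist())
        train_indices.extend(perm[n_val:].tolist())
    return Subset(dataset, train_indices), Subset(dataset, val_indices)
```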
I've started implementing a generalized split function for PL, so I'll create a PR sometime this weekend for you to review, if that makes sense? :)
> The reason I moved to the PyTorch split was to drop the dependency on sklearn, as this seems to be the only usage... :] So shall we use it again, or just implement the split ourselves?
Thanks a lot for your great pl_bolts. I find that some DataModules, such as the ImageNet one (imagenet_dataset), already depend on sklearn, so it looks like this is not a serious problem... maybe?