
Is there a way to use torch's IterableDataset?

Open GPhilo opened this issue 2 years ago • 2 comments

Describe the question you meet

Background information

I have a very imbalanced dataset for binary classification, and I want to dynamically generate artificial samples based on some target sample distribution (e.g., imagine I want to produce batches with approximately 50% positive samples, 25% real negative, and 25% artificial negative). After looking at the available classes, I couldn't find a way to implement this behaviour, so I thought of implementing a custom IterableDataset that produces samples according to the desired distribution.

All the Dataset classes I see, however, extend torch.utils.data.Dataset, hence the question: can torch.utils.data.IterableDataset be used at all?
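The behaviour described above can be sketched as a plain torch IterableDataset that draws each sample from one of three sources according to the target probabilities. This is a minimal sketch, not mmcls-specific; `make_artificial` is a hypothetical callable standing in for whatever synthesizes an artificial negative:

```python
import random
from torch.utils.data import IterableDataset

class MixtureDataset(IterableDataset):
    """Yield (sample, label) pairs following a fixed target distribution.

    probs = (p_positive, p_real_negative, p_artificial_negative).
    `make_artificial` is a hypothetical factory for synthetic negatives.
    """

    def __init__(self, positives, negatives, make_artificial,
                 probs=(0.5, 0.25, 0.25), num_samples=1000):
        self.positives = positives
        self.negatives = negatives
        self.make_artificial = make_artificial
        self.probs = probs
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            r = random.random()
            if r < self.probs[0]:
                # ~50% positives, label 1
                yield random.choice(self.positives), 1
            elif r < self.probs[0] + self.probs[1]:
                # ~25% real negatives, label 0
                yield random.choice(self.negatives), 0
            else:
                # ~25% artificial negatives, label 0
                yield self.make_artificial(), 0
```

Wrapping this in a DataLoader would then produce batches with roughly the desired class mix; with multiple workers each worker would need its own seed to avoid duplicated streams.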

GPhilo avatar Jul 08 '22 11:07 GPhilo

It is recommended to use https://mmclassification.readthedocs.io/en/latest/tutorials/new_dataset.html#class-balanced-dataset.

If you want to use IterableDataset, you will need to implement an IterBaseDataset along the lines of https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/base_dataset.py
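For reference, the suggested ClassBalancedDataset is used as a dataset wrapper in the config. A hedged sketch, assuming an ImageNet-style dataset config (the dataset type, path, and threshold value are placeholders):

```python
# Hypothetical config fragment: wrap the training dataset with
# ClassBalancedDataset. `oversample_thr` is the repeat-factor threshold
# below which a class gets oversampled.
data = dict(
    train=dict(
        type='ClassBalancedDataset',
        oversample_thr=1e-3,
        dataset=dict(
            type='ImageNet',            # replace with your dataset type
            data_prefix='data/train',   # replace with your data path
            pipeline=[],                # your training pipeline here
        ),
    ),
)
```

Note that this only rebalances classes toward uniformity via repeat factors; it does not let you target an arbitrary distribution.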

Ezra-Yu avatar Jul 11 '22 02:07 Ezra-Yu

@Ezra-Yu my problem with ClassBalancedDataset is that it gives no control over the actual target class distribution: it just balances the classes so that they're approximately uniform, while I want a target distribution of (1/2, 1/4, 1/4) samples. I also looked into the samplers and implemented a distributed version of torch's WeightedRandomSampler, though it's untested and I don't know if I did it right.
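A distributed WeightedRandomSampler along these lines could look like the sketch below (this is an illustration, not the poster's actual implementation): every rank performs the same seeded weighted draw, then keeps only its own shard of the indices, so the ranks stay disjoint but consistent.

```python
import torch
from torch.utils.data import Sampler

class DistributedWeightedSampler(Sampler):
    """Sketch: weighted sampling with replacement, sharded across ranks.

    `weights` is a per-sample weight sequence; in real use,
    `num_replicas` and `rank` would come from torch.distributed.
    """

    def __init__(self, weights, num_samples, num_replicas=1, rank=0, seed=0):
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples  # draws yielded per replica
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        g = torch.Generator()
        # Same seed on every rank so the global draw is identical;
        # each rank then keeps a strided shard of the indices.
        g.manual_seed(self.seed + self.epoch)
        indices = torch.multinomial(
            self.weights, self.num_samples * self.num_replicas,
            replacement=True, generator=g)
        return iter(indices[self.rank::self.num_replicas].tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Call once per epoch (like DistributedSampler) to reshuffle.
        self.epoch = epoch
```

With per-sample weights set to target_prob / class_frequency, this would approximate the (1/2, 1/4, 1/4) mix without wrapping the dataset itself.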

GPhilo avatar Jul 12 '22 08:07 GPhilo

The WeightedRandomSampler looks like a good feature; maybe you can create a PR to add it to mmcls.

mzr1996 avatar Oct 13 '22 03:10 mzr1996