Is there a way to use torch's IterableDataset?
Describe the question
Background information
I have a very imbalanced dataset for binary classification and I want to dynamically generate artificial samples based on some target sample distribution (e.g., batches with approximately 50% positive samples, 25% real negatives, and 25% artificial negatives). After looking at the available classes, I couldn't find a way to implement this behaviour, so I thought of implementing a custom IterableDataset that produces samples according to the desired distribution.
All the Dataset classes I see, however, extend torch.utils.data.Dataset, so the question is: can torch.utils.data.IterableDataset be used at all?
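To make the idea concrete, here is a minimal sketch of the kind of dataset I have in mind; `positives`, `negatives` and `make_synthetic_negative` are hypothetical placeholders for the real data sources:

```python
import random
from torch.utils.data import IterableDataset

class MixtureDataset(IterableDataset):
    """Yield samples according to a fixed target distribution.

    Sketch only: `positives` and `negatives` are sequences of real
    samples, `make_synthetic_negative` generates an artificial one.
    """

    def __init__(self, positives, negatives, make_synthetic_negative,
                 probs=(0.5, 0.25, 0.25), num_samples=10000):
        self.positives = positives
        self.negatives = negatives
        self.make_synthetic_negative = make_synthetic_negative
        self.probs = probs
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            source = random.choices(range(3), weights=self.probs)[0]
            if source == 0:    # real positive
                yield random.choice(self.positives)
            elif source == 1:  # real negative
                yield random.choice(self.negatives)
            else:              # artificial negative
                yield self.make_synthetic_negative()
```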
It is recommended to use ClassBalancedDataset, see https://mmclassification.readthedocs.io/en/latest/tutorials/new_dataset.html#class-balanced-dataset.
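For reference, the wrapper is configured roughly like this; the wrapped dataset and the `oversample_thr` value are illustrative placeholders to adapt to your setup:

```python
# Hypothetical training-data config snippet.
data = dict(
    train=dict(
        type='ClassBalancedDataset',
        oversample_thr=1e-3,  # repeat classes whose frequency is below this
        dataset=dict(
            type='CustomDataset',
            data_prefix='data/train',
            pipeline=train_pipeline,  # assumed defined elsewhere in the config
        ),
    ),
)
```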
If you want to use IterableDataset, you will need to implement an IterBaseDataset similar to https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/base_dataset.py.
@Ezra-Yu my problem with ClassBalancedDataset is that it gives no control over the actual target class distribution; it just balances the classes so that they're approximately uniform, whereas I want a target distribution of (1/2, 1/4, 1/4).
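Concretely, an arbitrary target mix can be reached by giving each sample the weight target_prob[group] / group_count[group], so that the expected draws per group match the target; a small sketch with hypothetical group labels:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical group label per sample:
# 0 = positive, 1 = real negative, 2 = artificial negative.
groups = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2])
target = torch.tensor([0.5, 0.25, 0.25])  # desired class mix

counts = torch.bincount(groups, minlength=3).float()
weights = target[groups] / counts[groups]  # per-group mass sums to target

sampler = WeightedRandomSampler(weights, num_samples=len(groups),
                                replacement=True)
```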
I also looked into the samplers and implemented a distributed version of torch's WeightedRandomSampler, though it's untested and I don't know if I did it right.
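For reference, the idea was roughly the following (an untested sketch; it assumes the caller passes in the rank and world size, e.g. obtained from torch.distributed):

```python
import torch
from torch.utils.data import Sampler

class DistributedWeightedSampler(Sampler):
    """Weighted sampling with replacement, partitioned across ranks.

    Sketch only: every rank draws from the same weighted distribution
    (seeded identically) but keeps only its own shard of the indices.
    """

    def __init__(self, weights, num_samples, num_replicas, rank, seed=0):
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples  # samples yielded per rank
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # identical stream on all ranks
        indices = torch.multinomial(
            self.weights, self.num_samples * self.num_replicas,
            replacement=True, generator=g)
        # Each rank takes a disjoint, interleaved shard of the draws.
        return iter(indices[self.rank::self.num_replicas].tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called by the training loop so each epoch reshuffles consistently.
        self.epoch = epoch
```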
Looks like the WeightedRandomSampler is a good feature; maybe you can create a PR to add it to mmcls.