MLDataPattern.jl
BalancedBatches for labeled data
It would be nice to have a new data iterator that samples the given data such that every iteration returns a batch containing an equal number of observations from each class, regardless of the class distribution.
X = rand(2, 6) # some features
y = [:a, :a, :a, :a, :b, :b]
for (xbatch, ybatch) in BalancedBatches((X, y), size = 2, count = 10)
    # ybatch is always either [:a, :b] or [:b, :a]
end
- see `RandomBatches` for an example of a batch iterator: https://github.com/JuliaML/MLDataPattern.jl/blob/master/src/dataiterator.jl
- see `stratifiedobs` for an example of how to compute the indices of the observations that belong to each class (look at how `labelmap` is used)
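To make the idea concrete, here is a rough sketch in plain Julia of how the sampling could work. It is not the MLDataPattern.jl API: `class_indices`, `sample_balanced`, and `balanced_batches` are hypothetical names, the batches are built eagerly instead of through a lazy iterator, and it assumes sampling with replacement within each class (which the minority class needs anyway) and a batch size that is a multiple of the number of classes.

```julia
using Random

# Hypothetical helper: map each label to the indices of its observations
# (similar in spirit to what `labelmap` is used for in `stratifiedobs`).
function class_indices(y::AbstractVector)
    lm = Dict{eltype(y), Vector{Int}}()
    for i in eachindex(y)
        push!(get!(lm, y[i], Int[]), i)
    end
    return lm
end

# Hypothetical helper: draw `per_class` indices from every class
# (with replacement), then shuffle so the class order within a batch varies.
function sample_balanced(rng::AbstractRNG, lm::AbstractDict, per_class::Int)
    idx = Int[]
    for inds in values(lm)
        append!(idx, rand(rng, inds, per_class))
    end
    return shuffle!(rng, idx)
end

# Hypothetical sketch: return `count` batches of `size` observations each,
# where observations are the columns of `X`.
function balanced_batches(X::AbstractMatrix, y::AbstractVector;
                          size::Int, count::Int,
                          rng::AbstractRNG = Random.default_rng())
    lm = class_indices(y)
    nclasses = length(lm)
    size % nclasses == 0 ||
        throw(ArgumentError("size must be a multiple of the number of classes"))
    per_class = size ÷ nclasses
    getbatch(idx) = (X[:, idx], y[idx])
    return [getbatch(sample_balanced(rng, lm, per_class)) for _ in 1:count]
end

X = rand(2, 6)                      # some features
y = [:a, :a, :a, :a, :b, :b]
for (xbatch, ybatch) in balanced_batches(X, y, size = 2, count = 10)
    @assert sort(ybatch) == [:a, :b]    # one observation per class
end
```

A real implementation would presumably wrap this in an iterator type like `RandomBatches` and accept any data container (such as the tuple `(X, y)`), not just a matrix and a label vector.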
Agreed. Now that RandomObs/RandomBatches are available, I believe it makes sense to have BalancedBatches alongside BalancedObs.