MLDataPattern.jl
Add stratified k-folds repartitioning strategy
I just implemented `stratifiedkfolds` for my own work. It might be something that could live in this package. I'm not too familiar with the internals of MLDataPattern, so it might not be implemented in the best way:
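```julia
using MLDataUtils
import MLDataUtils: default_obsdim

# Stratified k-folds: every validation fold gets (roughly) the same share of
# each class. `f` is applied to each target before grouping by label.
# `obsdim` is a keyword (as in `kfolds`) so that an integer obsdim cannot be
# confused with `k` during dispatch.
function stratifiedkfolds(f, data, k::Integer = 5; obsdim = default_obsdim(data))
    # Shuffle first so the within-class fold assignment is random
    data_shuf = shuffleobs(data; obsdim = obsdim)
    # Dict mapping each label to the indices (into data_shuf) of its observations
    lm = labelmap(eachtarget(f, data_shuf; obsdim = obsdim))
    val_indices = map(_ -> Int[], 1:k)
    for indices in values(lm)
        # Split this class's observations into k folds
        # (kfolds requires at least k observations per class)
        _, lm_val_indices = kfolds(length(indices), k)
        for i in 1:k
            # Map positions within `indices` back to observation indices
            append!(val_indices[i], indices[lm_val_indices[i]])
        end
    end
    # Each fold trains on every observation outside its validation set
    train_indices = map(t -> setdiff(1:nobs(data_shuf; obsdim = obsdim), t), val_indices)
    FoldsView(data_shuf, train_indices, val_indices, obsdim)
end

# Convenience method: use the targets as-is
function stratifiedkfolds(data, k::Integer = 5; obsdim = default_obsdim(data))
    stratifiedkfolds(identity, data, k; obsdim = obsdim)
end
```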
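For readers who want to see the per-class fold assignment without the MLDataUtils machinery, the same idea can be sketched in plain Julia. This is a minimal sketch, assuming the targets are already available as a vector `y`; `stratified_fold_indices` is a hypothetical name (not part of MLDataPattern), it returns index vectors instead of a `FoldsView`, and it does its own chunking instead of calling `kfolds`, so the fold sizes may be distributed slightly differently from MLDataPattern's.

```julia
using Random

# Hypothetical helper (not part of MLDataPattern): stratified k-fold indices
# for a label vector `y`. Returns (train_indices, val_indices), each a vector
# of k index vectors into `y`.
function stratified_fold_indices(y::AbstractVector, k::Integer = 5;
                                 rng::AbstractRNG = Random.default_rng())
    k >= 2 || throw(ArgumentError("k must be at least 2"))
    val_indices = [Int[] for _ in 1:k]
    for class in unique(y)
        # Positions of this class, shuffled so fold membership is random
        idx = shuffle(rng, findall(==(class), y))
        n = length(idx)
        n >= k || throw(ArgumentError("class $class has fewer than $k observations"))
        # Split into k contiguous chunks whose sizes differ by at most one
        start = 1
        for i in 1:k
            len = div(n, k) + (i <= n % k ? 1 : 0)
            append!(val_indices[i], idx[start:start+len-1])
            start += len
        end
    end
    # Each fold trains on every observation outside its validation set
    train_indices = [setdiff(1:length(y), v) for v in val_indices]
    return train_indices, val_indices
end

y = [fill(0, 500); fill(1, 250); fill(2, 250)]
train_idx, val_idx = stratified_fold_indices(y, 5)
count(i -> y[i] == 1, val_idx[2])   # 50
```

The returned index vectors could then be wrapped the same way the proposed function does, with `FoldsView(data, train_indices, val_indices, obsdim)`.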
Example usage:
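```julia
X = rand(30, 1000)                               # 30 features, 1000 observations
y = [fill(0, 500); fill(1, 250); fill(2, 250)]   # imbalanced labels
data = (X, y)
folds = stratifiedkfolds(data, 5)
# Each fold is a (training, validation) pair of data subsets
(Xtrain, ytrain), (Xval, yval) = folds[2]
labelmap(ytrain)   # classes 0, 1, 2 with 400, 200, 200 indices
labelmap(yval)     # classes 0, 1, 2 with 100, 50, 50 indices
```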
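Inspecting `labelmap(ytrain)` and `labelmap(yval)` works for one fold; a unit test accompanying a PR could instead check all folds at once. The helper below is a hypothetical sketch (`is_stratified_partition` is not an MLDataPattern function). It assumes `y` holds the labels in the same order the indices refer to (inside `stratifiedkfolds`, that is the shuffled data) and that `val_indices` holds one vector of validation indices per fold. It checks that the validation sets partition the observations and that each class's count in every fold is the floor or ceiling of its total divided by `k`.

```julia
# Hypothetical test helper (not an MLDataPattern function): checks that
# `val_indices` (one vector of observation indices per fold) partitions
# 1:length(y) and that every class is spread as evenly as possible.
function is_stratified_partition(y::AbstractVector, val_indices)
    n, k = length(y), length(val_indices)
    # Every observation must appear in exactly one validation set
    sort(reduce(vcat, val_indices)) == collect(1:n) || return false
    for class in unique(y)
        total = count(==(class), y)
        for fold in val_indices
            c = count(i -> y[i] == class, fold)
            # Per-class count in each fold must be floor or ceil of total / k
            fld(total, k) <= c <= cld(total, k) || return false
        end
    end
    return true
end

y = [0, 0, 0, 0, 1, 1]
is_stratified_partition(y, [[1, 3, 5], [2, 4, 6]])   # true
is_stratified_partition(y, [[1, 2, 3], [4, 5, 6]])   # false: fold 1 has three 0s
```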
Hi @Evizero, do you think `stratifiedkfolds` is worth adding to the package?
Hi! Sure, I think it could be a cool addition. Sorry I didn't get to this; I'm afraid I'm not really active anymore.
I'm unsure who currently maintains this package. Are you still active, @oxinabox?
I think the current status of packages under JuliaML is that only bug reports and regular maintenance are handled. New feature requests usually get little response unless someone opens a PR.
FYI, the Flux community has recently been implementing a `DataLoader` alternative, partly because they think these packages are too delicate to touch and modify.
I am still around, but very busy. I won't make PRs to add new features, but I think this feature would be good to have, be it here or in @CarloLucibello's MLDataUtils2 successor.