MLUtils.jl

PR welcome on stratified K-folds?

Open SimonEnsemble opened this issue 3 years ago • 4 comments

I can make a first-round PR and am then willing to make whatever changes are necessary, if anyone is willing to coach me through it or look at my PR. Thanks.

SimonEnsemble avatar Feb 22 '22 20:02 SimonEnsemble

Yes, this would be a nice addition. You can probably adapt the MLDataPattern.jl code as a starting point. I would try and start with the simplest implementation first though, since part of the goal with this package is to cut down on the complexity of the MLDataPattern.jl codebase.

darsnack avatar Feb 22 '22 20:02 darsnack

I feel like we should start with the stratified version of splitting observations first.

This is the simplest/clearest implementation I could come up with that works for (i) a target vector y with any number of classes and (ii) an at argument that could be a tuple. Is there a cleaner way to do this? Also, these are not views like in the rest of the package...

I will need some coaching... I can write a few tests for this though.

using MLUtils: shuffleobs, splitobs

function splitobs_stratified(; at, y::AbstractVector, shuffle::Bool=true)
    n_splits = length(at) + 1
    the_splits = [Int[] for _ in 1:n_splits]
    for label in unique(y)
        # indices of all observations carrying this label
        ids_this_label = findall(isequal(label), y)
        if shuffle
            ids_this_label = shuffleobs(ids_this_label)
        end
        # split this class's indices in the requested proportions
        split_this_label = splitobs(ids_this_label, at=at)
        for s in 1:n_splits
            append!(the_splits[s], split_this_label[s])
        end
    end
    return the_splits
end

targets = vcat(fill(-1, 10), fill(1, 100))
splits = splitobs_stratified(at=(0.2, 0.5), y=targets, shuffle=true)
for s in splits
    println(sum(targets[s] .== 1) / sum(targets[s] .== -1)) # 10.0 woo!
end
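
On the point about views: one possible direction, sketched here with only Base and the Random standard library rather than any MLUtils internals, is to compute the stratified index sets first and then wrap the data with Base's view. The names stratified_indices and stratified_views are hypothetical and not part of the package, and the sketch assumes that at is read as per-split fractions, with the remainder going to the last split.

```julia
using Random

# Hypothetical helper (not part of MLUtils): split the indices of each class
# at the given fractions and collect them per split.
function stratified_indices(y::AbstractVector, at; rng=Random.default_rng())
    fractions = at isa Number ? (at,) : Tuple(at)
    sum(fractions) <= 1 || throw(ArgumentError("fractions in `at` must sum to at most 1"))
    splits = [Int[] for _ in 1:length(fractions) + 1]
    for label in unique(y)
        # shuffled indices of the observations carrying this label
        ids = shuffle(rng, findall(isequal(label), y))
        n = length(ids)
        # cumulative cut points for this class, e.g. (0.2, 0.5) -> 0.2n, 0.7n
        cuts = [round(Int, n * c) for c in cumsum(collect(fractions))]
        bounds = [0; cuts; n]
        for s in eachindex(splits)
            append!(splits[s], ids[bounds[s]+1:bounds[s+1]])
        end
    end
    return splits
end

# Hypothetical wrapper: return views into a data vector instead of copies.
stratified_views(data::AbstractVector, y, at; kw...) =
    [view(data, idx) for idx in stratified_indices(y, at; kw...)]

# Usage with the same targets as above
targets = vcat(fill(-1, 10), fill(1, 100))
features = collect(1.0:110.0)
for v in stratified_views(features, targets, (0.2, 0.5))
    println(length(v), " observations, ", typeof(v) <: SubArray ? "view" : "copy")
end
```

Keeping the index computation separate from the wrapping step means the same index sets could be passed to whatever view mechanism the package settles on.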

SimonEnsemble avatar Feb 23 '22 07:02 SimonEnsemble

@SimonEnsemble your approach seems fine. You should open a PR, and we can discuss the details there.

CarloLucibello avatar Mar 10 '22 09:03 CarloLucibello