Flux.jl
Weights and positive weights for labels in Flux losses
Motivation and description
Hello folks, I searched the Flux issues and documentation but couldn't find anything about weights for labels in the built-in loss functions (or a reason for not having them).
Would this add too much boilerplate to the Flux losses? Looking at their code, they are generally very clean and concise.
If the recommendation is to build your own loss when you need weighted labels, shouldn't this be documented more explicitly, with one or more examples?
Possible Implementation
One can easily implement this with something like:
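function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ), weights = 1)
    _check_sizes(ŷ, y)  # a check that `weights` broadcasts against ŷ would also be needed
    # weights' turns a vector (one weight per column/sample) into a row that broadcasts over ŷ
    agg(@.(weights' * (-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))))
end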
A simple example:
julia> x = [[0.8] [0.1]]
1×2 Matrix{Float64}:
0.8 0.1
julia> y = [[1.0] [0.0]]
1×2 Matrix{Float64}:
1.0 0.0
julia> binarycrossentropy(x, y)
0.16425203348601775
julia> binarycrossentropy(x, y, weights=[0.75, 1.25])
0.1495291540289698
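For anyone who wants to try this outside Flux, here is a self-contained sketch that uses only the `Statistics` standard library. `weighted_bce` and the local `xlogy` helper are hypothetical stand-ins (not Flux API), and the weights are passed directly as a 1×2 row so they line up with the columns (samples) of `x`. It should reproduce the two numbers from the REPL session.

```julia
using Statistics: mean

# Local stand-in for Flux's internal `xlogy`: 0 when x == 0, otherwise x * log(y).
xlogy(x, y) = iszero(x) ? zero(float(x)) : x * log(y)

# Hypothetical weighted binary cross-entropy (not part of Flux's API).
# `weights` must broadcast against ŷ, e.g. a 1×N row for N samples.
function weighted_bce(ŷ, y; agg = mean, ϵ = eps(float(eltype(ŷ))), weights = 1)
    size(ŷ) == size(y) || throw(DimensionMismatch("ŷ and y must have the same size"))
    losses = @. -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)  # elementwise BCE
    return agg(weights .* losses)                            # weight, then aggregate
end

x = [0.8 0.1]
y = [1.0 0.0]
weighted_bce(x, y)                          # ≈ 0.1643 (unweighted)
weighted_bce(x, y; weights = [0.75 1.25])   # ≈ 0.1495 (one weight per sample)
```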
You can do that by passing a custom aggregation function that includes the weights; this is briefly mentioned in the docs. Something like:
loss(ŷ, y, agg = x -> mean(weights .* x)) # weighted mean
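A quick sketch of this `agg` approach, with a hypothetical standalone `bce` standing in for Flux's loss (only `Statistics.mean` is assumed). One detail worth spelling out: the weights have to broadcast elementwise against the unaggregated loss. For a 1×2 loss, a 1×2 row works, while a length-2 column vector silently broadcasts to a 2×2 matrix and, in this case, just gives back the unweighted mean.

```julia
using Statistics: mean

# Local stand-ins (hypothetical, not Flux's code): xlogy and an unweighted BCE with `agg`.
xlogy(x, y) = iszero(x) ? zero(float(x)) : x * log(y)
function bce(ŷ, y; agg = mean, ϵ = eps(float(eltype(ŷ))))
    losses = @. -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)
    return agg(losses)
end

x = [0.8 0.1]
y = [1.0 0.0]

# Weights shaped like the elementwise loss (1×2): a proper weighted mean.
w = [0.75 1.25]
weighted = bce(x, y; agg = l -> mean(w .* l))

# Pitfall: a length-2 column vector broadcasts against the 1×2 loss into a 2×2
# matrix, so every loss is multiplied by every weight.
w_col = [0.75, 1.25]
wrong = bce(x, y; agg = l -> mean(w_col .* l))
```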
Yes, this works for weighting all labels, but it wouldn't work if you only want to weight the positive labels, right?
Something like this for LogitBCE:
function logitbinarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ), pos_weights = 1)
    _check_sizes(ŷ, y)
    # a similar check that pos_weights broadcasts against ŷ would also be needed
    # pos_weights scales only the positive (y = 1) term of the loss
    agg(@.(-pos_weights * xlogy(y, σ(ŷ) + ϵ) - xlogy(1 - y, 1 - σ(ŷ) + ϵ)))
end
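If the ϵ-based form turns out to be fragile for large logits, the same positive weighting can be written with the identities -log σ(z) = softplus(-z) and -log(1 - σ(z)) = softplus(z), which avoid computing σ explicitly. This is a sketch under that assumption: `logit_bce_posweight` and the local `softplus` are hypothetical names (not Flux functions), and `pos_weight` multiplies only the y = 1 term, as in the snippet above.

```julia
using Statistics: mean

# Numerically stable softplus: log(1 + exp(z)) without overflow for large z.
softplus(z) = max(z, zero(z)) + log1p(exp(-abs(z)))

# Hypothetical positive-weighted BCE on logits (not Flux API).
#   -log(σ(z))     = softplus(-z)   -> positive term, scaled by pos_weight
#   -log(1 - σ(z)) = softplus(z)    -> negative term, unscaled
function logit_bce_posweight(ŷ, y; agg = mean, pos_weight = 1)
    size(ŷ) == size(y) || throw(DimensionMismatch("ŷ and y must have the same size"))
    losses = @. pos_weight * y * softplus(-ŷ) + (1 - y) * softplus(ŷ)
    return agg(losses)
end

z = [1.2 -0.4]   # logits
t = [1.0 0.0]    # targets
logit_bce_posweight(z, t)                   # plain logit BCE
logit_bce_posweight(z, t; pos_weight = 2)   # positive term counted twice
```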
PS: I know the actual implementation of logitbinarycrossentropy doesn't look like this; it's just a quick sketch of the idea.
For that case you can do
weights = @. y * pos_weight + (1 - y) * neg_weight
logitbinarycrossentropy(ŷ, y; agg = x -> mean(weights .* x))
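To check that the label-derived weights give the same result as a dedicated positive weight, here is a self-contained comparison (again with hypothetical local `softplus` and `logit_bce` helpers, not Flux's functions). With hard 0/1 targets the two agree; with soft targets they would differ, because the `agg` trick scales the whole elementwise loss rather than only its positive term.

```julia
using Statistics: mean

softplus(z) = max(z, zero(z)) + log1p(exp(-abs(z)))

# Unweighted logit BCE with an `agg` keyword (local stand-in, not Flux's code).
function logit_bce(ŷ, y; agg = mean)
    losses = @. y * softplus(-ŷ) + (1 - y) * softplus(ŷ)
    return agg(losses)
end

ŷ = [1.2 -0.4 2.0]   # logits
y = [1.0 0.0 1.0]    # hard 0/1 targets
pos_weight, neg_weight = 3.0, 1.0

# Per-element weights built from the labels, then folded into `agg`.
weights = @. y * pos_weight + (1 - y) * neg_weight
via_agg = logit_bce(ŷ, y; agg = l -> mean(weights .* l))

# The same quantity written out with separate positive/negative weights.
direct_losses = @. pos_weight * y * softplus(-ŷ) + neg_weight * (1 - y) * softplus(ŷ)
direct = mean(direct_losses)
```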
We can add this example to the docs.
Nice one! I think it would be good to have it there. In my experience, other frameworks (coming from Python) have this built in, so showing the more "Fluxian" way of doing it would be a nice addition.