
Weights and positive weights for labels in Flux losses

Open gabrielpreviato opened this issue 2 years ago • 5 comments

Motivation and description

Hello folks, I searched the Flux issues and documentation but couldn't find anything about supporting (or a reason for not supporting) weights for labels in the native loss functions.

Is it too much boilerplate for the Flux losses? Looking at their code they are generally very clean and concise.

If the recommendation is to build your own loss when you need weighted labels, shouldn't this be documented more explicitly, with one or more examples?

Possible Implementation

One can easily implement this with something like:

function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ), weights=1)
    _check_sizes(ŷ, y, weights) # for checking if the weights are also valid
    agg(@.(weights'*(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))))
end

A simple example:

julia> x = [[0.8] [0.1]]
1×2 Matrix{Float64}:
 0.8  0.1

julia> y = [[1.0] [0.0]]
1×2 Matrix{Float64}:
 1.0  0.0

julia> binarycrossentropy(x, y)
0.16425203348601775

julia> binarycrossentropy(x, y, weights=[0.75, 1.25])
0.1495291540289698
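
The snippet above relies on Flux internals (`_check_sizes`, `epseltype`, `xlogy`), so it will not run on its own. Below is a minimal standalone sketch of the same idea. The name `weighted_bce` and the local `xlogy` helper are hypothetical stand-ins and not part of Flux, and the weights are assumed to be per-sample (one per column).

```julia
using Statistics: mean

# Local stand-in for the internal helper: x * log(y), defined as 0 when x == 0.
function xlogy(x, y)
    result = x * log(y)
    return iszero(x) ? zero(result) : result
end

# Hypothetical standalone version of the proposal; `weighted_bce` is not a Flux function.
function weighted_bce(ŷ, y; agg = mean, ϵ = eps(float(eltype(ŷ))), weights = 1)
    size(ŷ) == size(y) || throw(DimensionMismatch("ŷ and y must have the same size"))
    # Assumption: a vector of weights holds one weight per sample (column).
    w = weights isa Number ? weights : reshape(weights, 1, :)
    return agg(@. w * (-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
end

ŷ = [0.8 0.1]
y = [1.0 0.0]

unweighted = weighted_bce(ŷ, y)                          # ≈ 0.16425
weighted   = weighted_bce(ŷ, y; weights = [0.75, 1.25])  # ≈ 0.14953
```

Both values agree with the REPL output shown above.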

gabrielpreviato avatar Jan 06 '23 21:01 gabrielpreviato

You can do that by passing a custom aggregation function that includes the weights. It is briefly mentioned in the docs; something like:

loss(ŷ, y, agg = x -> mean(weights .* x))    # weighted mean
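
As a concrete sketch of this suggestion, the following applies it to the data from the opening post. It assumes a Flux version that provides the `Flux.Losses` module and that the weights are laid out as a row, one per sample, so they broadcast against the 1×2 element-wise loss.

```julia
using Flux.Losses: binarycrossentropy
using Statistics: mean

ŷ = [0.8 0.1]
y = [1.0 0.0]
weights = [0.75 1.25]   # 1×2 row: one weight per sample (column)

# `agg` receives the element-wise losses, so the weighting happens inside the closure.
plain    = binarycrossentropy(ŷ, y)                                  # ≈ 0.16425
weighted = binarycrossentropy(ŷ, y; agg = x -> mean(weights .* x))   # ≈ 0.14953
```

This gives the same number as the `weights` keyword proposed in the opening post, without changing the loss function itself.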

CarloLucibello avatar Jan 06 '23 22:01 CarloLucibello

Yes, this would work for weighting all labels, but it wouldn't work if you want to weight only the positive labels, right?

Something like this for LogitBCE:

function logitbinarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ), pos_weights=1)
    _check_sizes(ŷ, y)
    # or _check_sizes(ŷ, y, weights) for checking if the weights are valid
    agg(@.(-pos_weights*xlogy(y, σ(ŷ) + ϵ) - xlogy(1 - y, 1 - σ(ŷ) + ϵ)))
end

PS: I know the actual implementation of logitbinarycrossentropy is not like this; it's just for a quick visualization of the idea.

gabrielpreviato avatar Jan 06 '23 23:01 gabrielpreviato

For that case you can do

weights = @. y * pos_weight + (1 - y) * neg_weight
logitbinarycrossentropy(ŷ, y; agg = x -> mean(weights .* x))

We can add this example to the docs.
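
A small runnable sketch of the two lines above, checked against the loss written out by hand. The logits, targets, and weight values are made up for illustration, and `sigm` is a local helper rather than a Flux function. Setting `neg_weight = 1` corresponds to the `pos_weights` variant proposed earlier in the thread.

```julia
using Flux.Losses: logitbinarycrossentropy
using Statistics: mean

ŷ = [1.5 -2.0 0.3]                  # raw logits, one column per sample (illustrative)
y = [1.0 0.0 1.0]                   # binary targets
pos_weight, neg_weight = 3.0, 1.0   # illustrative values

# Each element gets pos_weight where y == 1 and neg_weight where y == 0.
weights = @. y * pos_weight + (1 - y) * neg_weight
loss = logitbinarycrossentropy(ŷ, y; agg = x -> mean(weights .* x))

# The same quantity written out explicitly, for comparison.
sigm(x) = 1 / (1 + exp(-x))         # local logistic function
manual = mean(@. -pos_weight * y * log(sigm(ŷ)) - neg_weight * (1 - y) * log(1 - sigm(ŷ)))
```

Here `loss` and `manual` should agree up to floating-point error.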

CarloLucibello avatar Jan 06 '23 23:01 CarloLucibello

Nice one! I think it would be good to have it there. In my experience coming from Python frameworks, they have it built in, so I think showing this more "Fluxian" way of doing it would be a nice addition.

gabrielpreviato avatar Jan 07 '23 00:01 gabrielpreviato

One issue with the design of a weights API is that there's ambiguity between per-sample and per-class/feature/target weights. PyTorch solves this by...not solving it and having a weight param that means different things across different losses (if indeed they expose such a parameter at all).
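
To illustrate the ambiguity, here is a sketch showing that both interpretations can be expressed through `agg` with `crossentropy`, assuming the default `dims = 1` so that `agg` receives one loss value per sample. The logits, labels, and weights are made up for illustration.

```julia
using Flux: onehotbatch, softmax
using Flux.Losses: crossentropy
using Statistics: mean

logits = [2.0 0.1 0.3 1.2;
          0.5 1.5 0.2 0.4;
          0.1 0.3 2.2 0.9]           # 3 classes × 4 samples (illustrative)
ŷ = softmax(logits; dims = 1)
labels = [1, 2, 3, 1]
y = onehotbatch(labels, 1:3)

# Per-sample weights: one entry per column of the batch.
sample_w = [1.0 0.5 2.0 1.0]
per_sample = crossentropy(ŷ, y; agg = x -> mean(sample_w .* x))

# Per-class weights: look up each sample's weight from its label.
class_w = [1.0, 2.0, 0.5]
per_class = crossentropy(ŷ, y; agg = x -> mean(class_w[labels]' .* x))
```

A single `weights` keyword could mean either of these, which is the design ambiguity described above. The closure makes the choice explicit at the call site.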

ToucheSir avatar Jan 07 '23 00:01 ToucheSir