Type of vectors for aggregated losses overly restrictive?
It seems I can't use the loss functions in this package with either `SparseVector`s or Flux's tracked vectors. Is there a good reason why a general `AbstractVector{<:Number}` is not allowed?
```julia
julia> using LossFunctions

julia> using Flux

julia> v = Flux.param(rand(3))
Tracked 3-element Array{Float64,1}:
 0.8718259298110083
 0.9072480709509387
 0.9563588391481992

julia> value(L2DistLoss(), v, v)
ERROR: TypeError: in typeassert, expected Array{Float64,1}, got TrackedArray{…,Array{Float64,1}}
Stacktrace:
 [1] macro expansion at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:263 [inlined]
 [2] value at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:260 [inlined]
 [3] value(::LPDistLoss{2}, ::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:181
 [4] top-level scope at none:0
```
```julia
julia> using SparseArrays

julia> v = sparse(rand(4))
4-element SparseVector{Float64,Int64} with 4 stored entries:
  [1]  =  0.0101042
  [2]  =  0.780219
  [3]  =  0.304598
  [4]  =  0.214392

julia> value(L2DistLoss(), v, v)
ERROR: TypeError: in typeassert, expected Array{Float64,1}, got SparseVector{Float64,Int64}
Stacktrace:
 [1] macro expansion at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:263 [inlined]
 [2] value at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:260 [inlined]
 [3] value(::LPDistLoss{2}, ::SparseVector{Float64,Int64}, ::SparseVector{Float64,Int64}) at /Users/anthony/.julia/packages/LossFunctions/VQQPk/src/supervised/supervised.jl:181
 [4] top-level scope at none:0
```
Did you try `value.(...)` (i.e., broadcasting)?
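For concreteness, the suggested workaround would look something like the sketch below. The `Ref` wrapper (which the package's own code line quoted next also uses) makes broadcasting treat the loss object as a scalar instead of trying to iterate over it; `target` and `output` are just placeholder names, not part of the API.

```julia
using LossFunctions

target, output = rand(4), rand(4)  # stand-ins for any pair of vectors

# broadcast the scalar value method elementwise over both vectors;
# Ref(...) makes the loss object behave as a scalar under broadcasting
losses = value.(Ref(L2DistLoss()), target, output)
total = sum(losses)
```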
If I remember correctly, the reason the offending method has the `($($FUN)).(Ref(loss), target, output)::Array{S,$(max(N,M))}` line was that the bare broadcast led to failed return-type inference. That was quite a while ago, so it might be that we can just remove the `::Array{S,$(max(N,M))}` bit.
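To illustrate the inference point (a hypothetical, boiled-down example, not the package's actual code): with `S` available as a type parameter, the typeassert hands the compiler a concrete return type even if inference over the broadcast machinery gives up. But it is also exactly what rejects the `SparseVector`/`TrackedArray` inputs above, since broadcasting over those does not return a plain `Array`.

```julia
# Hypothetical reduction of the old pattern. The typeassert pins the
# return type to a concrete Array{S,1} for the compiler...
function aggregated_value(loss, target::AbstractVector{S},
                          output::AbstractVector{S}) where {S<:Number}
    value.(Ref(loss), target, output)::Array{S,1}
end
# ...but it throws for inputs whose broadcast result is not a plain
# Array, e.g. two SparseVectors broadcast to another SparseVector.
```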
Thanks for that workaround, which works.
If someone can check the necessity of the typing that would be great.
@ablaom can you please double-check that this issue is still valid on the latest master? I refactored the entire package months ago to clean up the type annotations, etc. If the issue persists, can you propose an improvement as a PR?
Very sorry @juliohm, but I just don't have the bandwidth right now to work on LossFunctions.jl. My post includes an MWE, which should be easy enough to check.
This shouldn't be an issue in the latest version. The loss function objects now only accept scalars as input; users are expected to use broadcasting to obtain the vector of losses, or the functions `mean` and `sum` for aggregation.
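For anyone landing here later, the new-style usage would look roughly like this (a sketch based on the description above, assuming the callable-loss form; check the current README for exact signatures):

```julia
using LossFunctions, Statistics

loss = L2DistLoss()
ŷ, y = rand(4), rand(4)

loss(ŷ[1], y[1])           # scalar evaluation
losses = loss.(ŷ, y)       # broadcasting gives the vector of losses
mean(losses), sum(losses)  # aggregate with the usual Base/Statistics functions
```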