Flux.jl
Flux.jl copied to clipboard
Is `Flux.huber_loss` type-unstable?
It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote. Is that the case?
using Flux, Zygote
import Statistics: mean
"""
    internfunc_nobroad(m, x, y)

Evaluate `m` on `x` and return `Flux.mse` of that output against `y`.
"""
function internfunc_nobroad(m, x, y)
    return Flux.mse(m(x), y)
end
"""
    internfunc_nobroad_huberloss(m, x, y)

Evaluate `m` on `x` and return `Flux.huber_loss` of that output against `y`.
"""
function internfunc_nobroad_huberloss(m, x, y)
    return Flux.huber_loss(m(x), y)
end
"""
    wrapfunc(model, xdata, ydata, func)

Return the result of `Zygote.gradient` for the closure `m -> func(m, xdata, ydata)`,
evaluated at `model`.
"""
function wrapfunc(model, xdata, ydata, func)
    # Rebind the data arguments in a `let` so the closure captures fresh locals.
    return let xs = xdata, ys = ydata
        Zygote.gradient(model) do m
            func(m, xs, ys)
        end
    end
end
# Three dense layers (5 => 3 => 3 => 1); relu on the first two, none given for the last.
fc = Flux.Chain(
    Flux.Dense(5 => 3, Flux.relu),
    Flux.Dense(3 => 3, Flux.relu),
    Flux.Dense(3 => 1),
)
# Constant Float32 arrays passed below as the inputs (5×10) and labels (1×10).
fobs_ar = fill(5.0f0, 5, 10)
labels_ar = fill(2.0f0, 1, 10)
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
I don't know why this is unstable, the ways of Zygote are mysterious sometimes.
The loss broadcasts this function, which contains some odd things. `abs_error .< δ` is strange, since both arguments are scalars here. `ignore_derivatives` is also strange, since Zygote shouldn't reach this code at all: the broadcast is differentiated with ForwardDiff, as you can confirm with `@show`. But commenting out the `ignore_derivatives` line doesn't fix anything.
julia> @eval Flux.Losses @inline function _huber_metric(abs_error, δ)
# Debugging override evaluated into the Flux.Losses module: the comparison mask is
# replaced by a constant and δ is printed on every call.
#TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
# Mask hard-coded to `false`; the expression it replaces is kept in the trailing comment.
temp = false # Zygote.ignore_derivatives(abs_error .< δ)
# NOTE(review): `ofeltype` is not shown here; presumably it converts 0.5 to the
# element type of `abs_error` (confirm against Flux.Losses).
x = ofeltype(abs_error, 0.5)
# Print δ on each call to see what type the broadcast passes in.
@show δ
# First term is the quadratic branch (weighted by temp), second the linear branch
# (weighted by 1 - temp); with temp = false only the linear branch contributes.
((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
end
_huber_metric (generic function with 7 methods)
julia> wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
((layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0], bias = Float32[1.0000001], σ = nothing)),),)