Flux.jl
Flux.jl copied to clipboard
Is `Flux.huber_loss` type-unstable?
It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote. Is that expected?
using Flux, Zygote
import Statistics: mean
"""
    internfunc_nobroad(m, x, y)

Apply the model `m` to `x` and return `Flux.mse` of the model output against `y`.
"""
function internfunc_nobroad(m, x, y)
    return Flux.mse(m(x), y)
end
"""
    internfunc_nobroad_huberloss(m, x, y)

Apply the model `m` to `x` and return `Flux.huber_loss` of the model output against `y`.
"""
function internfunc_nobroad_huberloss(m, x, y)
    prediction = x |> m
    return Flux.huber_loss(prediction, y)
end
"""
    wrapfunc(model, xdata, ydata, func)

Differentiate `m -> func(m, xdata, ydata)` with respect to `m`, evaluated at `model`,
and return whatever `Zygote.gradient` returns for that call.

`func` is invoked as `func(m, xdata, ydata)`.
"""
function wrapfunc(model, xdata, ydata, func)
    # Rebind the data inside a `let` so the closure captures fresh local bindings.
    return let xs = xdata, ys = ydata
        loss = m -> func(m, xs, ys)
        Zygote.gradient(loss, model)
    end
end
# Dense network 5 -> 3 -> 3 -> 1; the first two layers use ReLU.
fc = Flux.Chain(
    Flux.Dense(5 => 3, Flux.relu),
    Flux.Dense(3 => 3, Flux.relu),
    Flux.Dense(3 => 1),
)
# Constant Float32 input: a 5x10 matrix filled with 5.
fobs_ar = fill(5.0f0, 5, 10)
# Constant Float32 target: a 1x10 matrix filled with 2.
labels_ar = fill(2.0f0, 1, 10)
# Inspect the inferred types of the gradient wrapper when the loss is `Flux.mse`.
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)
# Same inspection with `Flux.huber_loss` as the loss, for comparison.
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)