Flux.jl

Is `Flux.huber_loss` type-unstable?

Open · filchristou opened this issue 1 year ago · 1 comment

It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote. Compare a gradient through `Flux.mse` with one through `Flux.huber_loss`:

using Flux, Zygote
import Statistics: mean

function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    Flux.mse(modelvals, y)
end

function internfunc_nobroad_huberloss(m, x, y)
    modelvals = m(x)
    Flux.huber_loss(modelvals, y)
end

function wrapfunc(model, xdata, ydata, func)
    grad = let xdata=xdata, ydata=ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end

fc = Flux.Chain(Flux.Dense(5=>3, Flux.relu), Flux.Dense(3=>3, Flux.relu), Flux.Dense(3=>1))

fobs_ar = fill(5f0, 5, 10)
labels_ar = fill(2f0, 1, 10)
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)

[image: `@code_warntype` output for the `Flux.mse` version]

julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)

[image: `@code_warntype` output for the `Flux.huber_loss` version]

filchristou · Jun 17 '24 14:06

I don't know why this is unstable; the ways of Zygote are sometimes mysterious.

The loss broadcasts the helper `_huber_metric`, which contains some odd things. `abs_error .< δ` is strange, since both arguments are scalars. `ignore_derivatives` is also strange, since Zygote shouldn't reach this code at all: the broadcast is differentiated with ForwardDiff, as you can confirm with `@show` (δ prints as a `Dual`). But commenting out that line doesn't fix anything.

julia> @eval Flux.Losses @inline function _huber_metric(abs_error, δ)
           #TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
           temp = false # Zygote.ignore_derivatives(abs_error .<  δ)
           x = ofeltype(abs_error, 0.5)
           @show δ
           ((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
       end
_huber_metric (generic function with 7 methods)

julia> wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
((layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0], bias = Float32[1.0000001], σ = nothing)),),)
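For comparison, here is a minimal, dependency-free sketch of a scalar Huber metric that avoids both oddities mentioned above: it compares two scalars with a plain `<` and picks the branch with `ifelse`, with no `.<` and no `ignore_derivatives`. The names `huber_metric_sketch` and `huber_loss_sketch` are hypothetical; this is not Flux's implementation. It only checks, via `Test.@inferred`, that the scalar function itself infers a concrete return type. Whether a change like this makes the Zygote gradient type-stable has not been tested here.

```julia
using Statistics: mean
using Test: @inferred

# Hypothetical scalar Huber metric, NOT Flux's internal `_huber_metric`:
#   0.5 * a^2            if a < δ
#   δ * (a - 0.5 * δ)    otherwise
function huber_metric_sketch(abs_error::Real, δ::Real)
    a, d = promote(abs_error, δ)   # one common float type, so both branches agree
    half = oftype(a, 0.5)          # keep the constant in that type (0.5f0 for Float32)
    quadratic = half * a * a
    linear = d * (a - half * d)
    # Plain scalar comparison; `ifelse` evaluates both branches and returns one
    return ifelse(a < d, quadratic, linear)
end

# Hypothetical array-level loss: broadcast the scalar metric, then average
huber_loss_sketch(ŷ, y; δ = 1) = mean(huber_metric_sketch.(abs.(ŷ .- y), δ))

# `@inferred` throws if the inferred return type differs from the runtime type
@inferred huber_metric_sketch(0.5f0, 1f0)   # Float32
@inferred huber_metric_sketch(3f0, 1.0)     # Float64 after promotion
```

The `promote` call is the main design choice: without it, a `Float32` error mixed with a `Float64` δ would hand `ifelse` a `Float32` and a `Float64`, and inference would return a `Union` of the two.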

mcabbott · Oct 30 '24 18:10