ChainRules.jl

Gradients from keyword arguments dropped

Open · mcabbott opened this issue · 0 comments

The rules for accumulate and foldl compute no gradient for the init keyword. This produces silently wrong gradients, which is bad. The problem may be wider than I realised.

It would be better to return @not_implemented for init. In fact, perhaps every keyword argument everywhere should get @not_implemented, or something like it, where possible. Better still would be to return the true gradient, which IIRC these rules have enough information to compute. Can this be done?

Example from here: https://discourse.julialang.org/t/how-to-efficiently-build-ad-compatible-matrices-line-by-line/74632/17

function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K*x[:, i-1])
    end
    return x
end

K = rand(3,3)
xi = rand(3,1)
f(K, xi, 50)

function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

This gives the following. The Fill(1.0, 3, 1) returned for xi comes from hcat alone:

julia> using Zygote

julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)

mcabbott, Jan 15 '22 02:01