Zygote.jl

Dictionary indexing failure inside closure and structs

Open willtebbutt opened this issue 4 years ago • 14 comments

julia> Zygote.gradient(x -> (() -> x[:y])(), Dict(:y => 0.4))
(nothing,)

The gradient w.r.t. the y element of x should be 1, not nothing.

This bug doesn't occur with the equivalent closure-free function

julia> Zygote.gradient(x -> x[:y], Dict(:y => 0.4))
(Dict{Any,Any}(:y => 1.0),)

and appears to be Dict-specific since

julia> Zygote.gradient(x -> (() -> x.y)(), (y = 0.4,))
((y = 1.0,),)

This bug was introduced in 0.4.21 -- the correct result is obtained on 0.4.20. The bug persists on 0.4.22 and 0.5.

This is breaking for Stheno.jl.

@MikeInnes @CarloLucibello any thoughts on what might be causing this?

willtebbutt, Jun 27 '20 11:06

@DhairyaLGandhi do you have any thoughts on what might be causing this?

willtebbutt, Jul 25 '20 16:07

Here's another MWE. This one is a little more complex, because it matches a use case that I have.

Julia version = 1.5, Zygote version = 0.5.4

module GradsMVP

using Zygote

mutable struct Foo
    store::Dict{Symbol, Float64}
    score::Float64
end

function (f::Foo)(acc::Symbol, fn::Function, args...)
    val = getindex(f.store, acc)
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, acc, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(acc, call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(Dict(:x => 1.0), :x, 1.0, foo, 1.0)
println(ags)
println(gs) # = nothing

end # module

whereas this code works fine

module GradsMVP

using Zygote

mutable struct Foo
    store::Float64
    score::Float64
end

# No accessor symbol here: the store is a plain Float64, and the call site passes only (call, args...).
function (f::Foo)(fn::Function, args...)
    val = f.store
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(1.0, 1.0, foo, 1.0)
println(ags)
println(gs) # = 1.0

end # module

femtomc, Aug 06 '20 14:08

To fix this MWE, it suffices to define the adjoint for getindex:

Zygote.@adjoint getindex(d::Dict, acc) = getindex(d, acc), retgrad -> (retgrad, nothing)

I'm unsure if this will break something fundamental.

Edit: sorry, this is supposed to be retgrad

femtomc, Aug 06 '20 14:08

@DhairyaLGandhi it's not Zygote's version of getindex: printouts of grad show the correct gradients. This makes sense, since that's obviously something which has been tested numerous times.

Something else is happening in the pipeline.

femtomc, Aug 06 '20 15:08

PS This works correctly on 0.4.20, as @willtebbutt says. I just checked with my own codebase.

femtomc, Aug 06 '20 15:08
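
Since both reports above agree that 0.4.20 returns the correct gradient, one possible stopgap is to hold Zygote at that release. This is only a sketch using the standard Pkg API. Whether 0.4.20 resolves against the rest of a given environment (Julia version, Flux, IRTools) is an assumption that would need checking.

```julia
using Pkg

# Install the last release reported in this thread to give the correct gradient.
Pkg.add(Pkg.PackageSpec(name = "Zygote", version = "0.4.20"))

# Pin it so that later `Pkg.update()` calls do not move it forward.
Pkg.pin("Zygote")
```

Running `Pkg.free("Zygote")` releases the pin again once a fixed version is available.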

Are you suggesting that the gradient is correctly calculated but isn't actually returned to the user properly?

DhairyaLGandhi, Aug 11 '20 04:08

What's happening is entirely unclear to me. It's Dict-specific, and I could only reproduce the bug in conjunction with a closure 🤷

willtebbutt, Aug 11 '20 08:08

@DhairyaLGandhi when I print out accum in the adjoint for getindex, I see the correct gradients. But in the MWE above, the returned grad is nothing.

femtomc, Aug 11 '20 10:08

@DhairyaLGandhi @willtebbutt any update on this?

This is highly frustrating to me. I can't update to the latest version of Zygote, so I can't use the latest version of IRTools, so I can't use the latest version of Flux, which means I can't use neural networks in my PPs.

I have no idea where this bug is occurring, but I'm motivated to find it and fix it, especially since this worked in 0.4.20, so it can't be hard to find again, can it? Any ideas where to start looking?

femtomc, Aug 15 '20 00:08

I've set up a PR. I don't know what I'm doing, so I don't know if this fix breaks many other things. Please let me know.

femtomc, Aug 15 '20 02:08

@willtebbutt @DhairyaLGandhi did this happen to get squashed in recent tags/PRs?

femtomc, Sep 21 '20 14:09

Hmmm I'm not sure. @DhairyaLGandhi is more likely to know.

willtebbutt, Sep 21 '20 14:09

This problem is still present when the Dict is nested inside a NamedTuple:

julia> d = Dict("x"=>rand(2))
Dict{String, Vector{Float64}} with 1 entry:
  "x" => [0.626974, 0.519716]

julia> gradient(x -> sum(x["x"]), d)  #OK
(Dict{Any, Any}("x" => 2-element Fill{Float64}: entries equal to 1.0),)

julia> nt = (; data=rand(2))
(data = [0.7536687262661153, 0.34819635465370324],)

julia> gradient(x -> sum(x.data), nt)  #OK
((data = 2-element Fill{Float64}: entries equal to 1.0,),)

julia> ntd = (; data = Dict("x" => rand(2)))
(data = Dict("x" => [0.6917549230112572, 0.16463696222948876]),)

julia> gradient(x -> sum(x.data["x"]), ntd) #WRONG
(nothing,)

CarloLucibello, Jul 21 '21 09:07

Came across this issue and I see all MWEs passing with https://github.com/FluxML/Zygote.jl/pull/1248. If anyone still has a larger example to test, could you confirm it passes as well? Otherwise I'll consider this issue fixed if nothing pops up after a few days.

ToucheSir, Aug 09 '22 05:08

Closing, as all examples are fixed. Will add tests.

CarloLucibello, Nov 23 '22 12:11
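
For reference, the examples from this thread could be collected into a regression test along the following lines. This is only a sketch, not the test that was actually added to Zygote. The expected values follow the thread (a gradient of 1 for each indexed element). The testset name and the assumption that the gradients come back as Dict-like containers indexable by the original keys are illustrative.

```julia
using Test
using Zygote

@testset "Dict indexing inside closures and structs" begin
    # Closure capturing a Dict (the first example in this thread).
    g = Zygote.gradient(x -> (() -> x[:y])(), Dict(:y => 0.4))[1]
    @test g !== nothing
    @test g[:y] == 1.0

    # Closure-free baseline, which was never affected.
    g = Zygote.gradient(x -> x[:y], Dict(:y => 0.4))[1]
    @test g[:y] == 1.0

    # Dict nested inside a NamedTuple (the July 2021 example).
    ntd = (data = Dict("x" => rand(2)),)
    g = Zygote.gradient(x -> sum(x.data["x"]), ntd)[1]
    @test g !== nothing
    @test g.data["x"] == ones(2)
end
```

On an affected version the first and third cases would fail at the `g !== nothing` checks, which pins down exactly the behaviour reported here.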