Zygote.jl
Dictionary indexing failure inside closures and structs
julia> Zygote.gradient(x -> (() -> x[:y])(), Dict(:y => 0.4))
(nothing,)
The gradient w.r.t. the y element of x should be 1.
This bug doesn't occur with the equivalent closure-free function:
julia> Zygote.gradient(x -> x[:y], Dict(:y => 0.4))
(Dict{Any,Any}(:y => 1.0),)
and appears to be Dict-specific, since the NamedTuple equivalent works:
julia> Zygote.gradient(x -> (() -> x.y)(), (y = 0.4,))
((y = 1.0,),)
This bug was introduced in 0.4.21; the correct result is obtained on 0.4.20. The bug persists on 0.4.22 and 0.5.
This is breaking for Stheno.jl.
@MikeInnes @CarloLucibello any thoughts on what might be causing this?
@DhairyaLGandhi do you have any thoughts on what might be causing this?
Here's another MWE. This one is a little more complex, because it matches a use case that I have.
Julia version = 1.5, Zygote version = 0.5.4
module GradsMVP

using Zygote

# A mutable struct holding a Dict of parameters and an accumulated score.
mutable struct Foo
    store::Dict{Symbol, Float64}
    score::Float64
end

# Calling a Foo looks up acc in the store, adds fn(val) to the score,
# and returns fn applied to the remaining arguments.
function (f::Foo)(acc::Symbol, fn::Function, args...)
    val = getindex(f.store, acc)
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, acc, ret_grad, call, args...)
    # The Dict is wrapped in a Foo that is constructed inside this closure.
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(acc, call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(Dict(:x => 1.0), :x, 1.0, foo, 1.0)
println(ags)
println(gs) # = nothing

end # module
whereas this code works fine
module GradsMVP

using Zygote

# Same setup, but the store is a plain Float64 instead of a Dict.
mutable struct Foo
    store::Float64
    score::Float64
end

# No lookup key is needed here, so the Symbol argument is dropped
# (get_grads below calls f(call, args...)).
function (f::Foo)(fn::Function, args...)
    val = f.store
    ret = fn(val)
    f.score += ret
    fn(args...)
end

function get_grads(store, ret_grad, call, args...)
    fn = (args, store) -> begin
        f = Foo(store, 0.0)
        ret = f(call, args...)
        (f.score, ret)
    end
    _, back = Zygote.pullback(fn, args, store)
    arg_grads, store_grads = back((1.0, ret_grad))
    return arg_grads, store_grads
end

function foo(a::Float64)
    return a
end

ags, gs = get_grads(1.0, 1.0, foo, 1.0)
println(ags)
println(gs) # = 1.0

end # module
To fix this MWE, it suffices to define the adjoint for getindex:
Zygote.@adjoint getindex(d::Dict, acc) = getindex(d, acc), retgrad -> (retgrad, nothing)
I'm unsure if this will break something fundamental.
Edit: sorry, this is supposed to be retgrad
@DhairyaLGandhi it's not Zygote's version of getindex: printouts of grad show the correct gradients. This makes sense, since that's obviously something which has been tested numerous times. Something else is happening in the pipeline.
PS: This works correctly on 0.4.20, as @willtebbutt says. I just checked with my own codebase.
Are you suggesting that the gradient is correctly calculated but isn't actually returned to the user properly?
What's happening is entirely unclear to me, since it's Dict-specific and I could only reproduce the bug in conjunction with a closure 🤷
@DhairyaLGandhi when I print out accum in the adjoint for getindex, I see the correct gradients. But in the MWE above, the returned grad is nothing.
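One way to check this from user code, without editing Zygote's source, is Zygote.hook, which applies a function to the gradient of its argument during the backward pass. The sketch below records the cotangent that reaches x[:y] inside the closure; the seen Ref is an illustrative name, not part of the original report. If the behaviour matches the description above, an affected version would be expected to record 1.0 in seen while still returning nothing for x.

```julia
using Zygote

# Illustrative holder for the cotangent observed at x[:y].
seen = Ref{Any}(nothing)

g = Zygote.gradient(Dict(:y => 0.4)) do x
    # Same shape as the original MWE: index the Dict inside a closure,
    # but wrap the result in a hook that stores its incoming gradient.
    inner = () -> Zygote.hook(ȳ -> (seen[] = ȳ; ȳ), x[:y])
    inner()
end

println("gradient reaching x[:y]: ", seen[])
println("gradient returned for x: ", g[1])
```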
@DhairyaLGandhi @willtebbutt any update on this?
This is highly frustrating to me. I can't update to the latest version of Zygote, so I can't use the latest version of IRTools, so I can't use the latest version of Flux, which means I can't use neural networks in my PPs.
I have no idea where this bug is occurring, but I'm motivated to find it and fix it, especially since it worked correctly in 0.4.20, so it can't be that hard to find again, can it? Any ideas where to start looking?
Set up a PR. I don't know what I'm doing, so I don't know if this fix breaks many other things; please let me know.
@willtebbutt @DhairyaLGandhi did this happen to get squashed in recent tags/PRs?
Hmmm I'm not sure. @DhairyaLGandhi is more likely to know.
This problem is still present:
julia> d = Dict("x"=>rand(2))
Dict{String, Vector{Float64}} with 1 entry:
"x" => [0.626974, 0.519716]
julia> gradient(x -> sum(x["x"]), d) #OK
(Dict{Any, Any}("x" => 2-element Fill{Float64}: entries equal to 1.0),)
julia> nt = (; data=rand(2))
(data = [0.7536687262661153, 0.34819635465370324],)
julia> gradient(x -> sum(x.data), nt) #OK
((data = 2-element Fill{Float64}: entries equal to 1.0,),)
julia> ntd = (; data = Dict("x" => rand(2)))
(data = Dict("x" => [0.6917549230112572, 0.16463696222948876]),)
julia> gradient(x -> sum(x.data["x"]), ntd) #WRONG
(nothing,)
Came across this issue and I see all MWEs passing with https://github.com/FluxML/Zygote.jl/pull/1248. If anyone still has a larger example to test, could you confirm it passes as well? Otherwise I'll consider this issue fixed if nothing pops up after a few days.
Closing, as all examples are fixed. Will add tests.
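A minimal sketch of what regression tests for the MWEs in this thread could look like, assuming a Zygote release that includes the fix from the linked PR. The variable names and testset title are illustrative; the checks only cover the Dict closure case, the two controls that already worked, and the Dict nested inside a NamedTuple.

```julia
using Test
using Zygote

# Gradients for each MWE from this thread.
g_closure = Zygote.gradient(x -> (() -> x[:y])(), Dict(:y => 0.4))  # original report
g_plain   = Zygote.gradient(x -> x[:y], Dict(:y => 0.4))            # closure-free control
g_nt      = Zygote.gradient(x -> (() -> x.y)(), (y = 0.4,))         # NamedTuple control
ntd       = (; data = Dict("x" => rand(2)))
g_nested  = Zygote.gradient(x -> sum(x.data["x"]), ntd)             # Dict inside a NamedTuple

@testset "Dict gradients through closures and structs" begin
    @test g_closure[1] !== nothing
    @test g_closure[1][:y] == 1.0
    @test g_plain[1][:y] == 1.0
    @test g_nt[1].y == 1.0
    @test g_nested[1] !== nothing
    # The gradient of sum w.r.t. a length-2 vector is all ones.
    @test g_nested[1].data["x"] == ones(2)
end
```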