Zygote.jl
gradients with aliased variables
I was trying to figure out how to properly handle and update Flux's layers with tied weights (https://github.com/FluxML/Flux.jl/issues/1592).
So first of all I wanted to check how Zygote handles aliased objects. Here are 6 examples. Maybe it's all expected and intended but I find the last 3 in particular a bit surprising. @oxinabox is this what we want?
julia> using Zygote
julia> x = [1]
1-element Vector{Int64}:
1
julia> xt = x'
1×1 adjoint(::Vector{Int64}) with eltype Int64:
1
# 1.
julia> gradient(() -> sum(x' .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
:(Main.x) => [2]
[1] => [2]
# 2.
julia> gradient(() -> sum(xt .* x), Params([x])).grads
IdDict{Any, Any} with 3 entries:
:(Main.x) => [1]
[1] => [1]
:(Main.xt) => [1]
# 3.
julia> gradient(() -> sum(xt .* x), Params([x,xt])).grads
IdDict{Any, Any} with 4 entries:
[1] => [1]
:(Main.x) => [1]
[1] => [1]
:(Main.xt) => [1]
# 4.
julia> gradient(() -> sum(xt.parent .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
:(Main.x) => [1]
[1] => [2]
# 5.
julia> gradient(() -> sum(xt.parent .* x), Params([x, xt])).grads
IdDict{Any, Any} with 3 entries:
[1] => nothing # this is xt
:(Main.x) => [1]
[1] => [2] # this is x
# 6.
julia> gradient(() -> sum(xt.parent .* x), Params([xt])).grads
IdDict{Any, Any} with 3 entries:
[1] => (parent = [1],)
:(Main.x) => [1]
:(Main.xt) => (parent = [1],)
I guess the most disturbing is 5. Shouldn't it return
[1] => (parent = [1],) # this is xt
:(Main.x) => [1]
[1] => [1] # this is x
instead?
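For reference, a small check of my own (not part of the session above): since xt.parent === x, the loss in example 5 is just sum(x .* x), and its total derivative with respect to x = [1] is 2x = [2]; the two [1] entries I wrote above are just that derivative split across the two occurrences of x.

@assert xt.parent === x
g, = gradient(v -> sum(v .* v), x)   # g == [2]: the two per-occurrence [1] contributions summed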
Putting aliased memory in Params feels like it's not going to be ok. I would need a fair bit of time to think about these.
(never mind, the thing I was missing is scribbling the wrong variables on my napkin)
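To make the aliasing concrete (again a check of mine, not from the thread): x and xt are two distinct objects that happen to share memory, so any identity-keyed container, Params or the grads IdDict, necessarily treats them as unrelated entries.

xt.parent === x              # true: same underlying memory
xt === x                     # false: different objects, hence different IdDict keys
objectid(xt) == objectid(x)  # false, for the same reason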
For a user-defined struct we have
julia> struct A; x; end
julia> x = rand(2); a = A(x);
julia> Base.sum(a) = sum(a.x)
julia> gradient(() -> sum(a), Params([x])).grads
IdDict{Any, Any} with 1 entry:
[0.573261, 0.457937] => 2-element Fill{Float64}: entries equal to 1.0
while for Adjoint something is wrong
julia> xt = Adjoint(x)
1×2 adjoint(::Vector{Float64}) with eltype Float64:
0.573261 0.457937
julia> gradient(() -> sum(xt), Params([x])).grads
IdDict{Any, Any} with 2 entries:
[0.573261, 0.457937] => nothing
:(Main.xt) => 1×2 Fill{Float64}: entries equal to 1.0
This seems expected... The grads actually also track global params as GlobalRefs, to capture tied variables.
They make sense, but that doesn't make them right/useful.
I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.
It seems hard not to consider the last example in https://github.com/FluxML/Zygote.jl/issues/991#issuecomment-864375988 a bug. I don't even know precisely why it happens; probably when we hit an AbstractArray{<:Number} in Zygote we don't look for internal structure. Is that the case?
> I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.
I'm not totally sure explicit gradients are a convenient fit for every situation; I'd like to see a diverse set of use cases where they replace params. In the last example, explicit gradients are at least consistent, although not quite useful:
julia> gradient(x -> sum(a), x)
(nothing,)
julia> gradient(x -> sum(xt), x)
(nothing,)
I think this illustrates why I consider explicit params better. It's obvious why the last example returned nothing. For the same reason, the Adjoint case returns nothing, but it is less obvious because we expect implicit params to pick up connections that aren't there in the function being differentiated.
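For contrast, a small sketch of my own (not among the examples above): when the adjoint or the struct is built from the function's argument, the connection is inside the differentiated code and the explicit gradient does flow back to x.

g1, = gradient(v -> sum(v'), x)     # adjoint built from the argument: g1 is a vector of ones
g2, = gradient(v -> sum(A(v)), x)   # same for the user-defined struct: g2 is a vector of ones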
One option is to add some kind of post-processing step where Params finds these connections and applies a fix. But I feel that it is hard to do correctly in the generic case.
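To illustrate why, here is a rough and purely hypothetical sketch of such a post-processing pass (merge_aliased! is my own name, not an existing Zygote function). It only knows how to fold gradients for two hard-coded lazy wrappers; views, reshapes, vector-vs-matrix shape bookkeeping and user structs holding the same array would all need their own handling.

using LinearAlgebra

# hypothetical: after the backward pass, fold a wrapper's gradient into its parent's entry
function merge_aliased!(grads::IdDict, ps)
    for p in ps
        p isa Union{Adjoint, Transpose} || continue   # only these two wrappers are handled
        gp = get(grads, p, nothing)
        gp isa AbstractArray || continue              # skip missing or structural gradients
        unwrapped = p isa Adjoint ? adjoint(gp) : transpose(gp)   # back to the parent's space
        gparent = get(grads, p.parent, nothing)
        grads[p.parent] = gparent === nothing ? unwrapped : gparent .+ unwrapped
    end
    return grads
end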
For example, something like https://github.com/FluxML/Flux.jl/issues/1592 works out nicely. Similar to the examples above, if we have
using Flux

m1 = Dense(5, 2)
m2 = Dense(transpose(m1.weight))
m = Chain(m1, m2)
dm = gradient(m -> sum(m(ones(Float32, 5))), m)[1]
Zygote will see the weight of m1 as w1 = w and the weight of m2 as w2 = transpose(w). It returns gradients w.r.t. w1 and w2 (as if they are not tied). But when we consider the part that Zygote doesn't see (w1 = w), we have, from the multivariate chain rule:
dL/dw = dL/dw1 * dw1/dw + dL/dw2 * dw2/dw
dw1/dw = 1
dw2/dw = 1 (up to transpose)
=> dL/dw = dL/dw1 + dL/dw2
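To check that last equation numerically, here is a self-contained sketch of my own (w and x0 below are names I made up; w1 = w and w2 = transpose(w) as in the text):

using Zygote, LinearAlgebra

w  = rand(Float32, 2, 5)
x0 = ones(Float32, 5)

# tied version: the transpose is taken inside the differentiated function
dw_true, = gradient(w -> sum(transpose(w) * (w * x0)), w)

# "untied" version: differentiate w.r.t. w1 and w2 as if they were independent
dw1, dw2 = gradient((w1, w2) -> sum(w2 * (w1 * x0)), w, transpose(w))

dw_true ≈ dw1 .+ transpose(dw2)   # true: dL/dw = dL/dw1 + dL/dw2 (up to the transpose)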
The final equation, dL/dw = dL/dw1 + dL/dw2, is handled automatically by simple optimizers like gradient descent, provided you use lazy wrappers like transpose or views. (@oxinabox can correct me if I am wrong here; my AD knowledge is very limited.)
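Concretely, with plain gradient descent and the names from the sketch just above (again my own illustration): because the second layer's weight is a lazy transpose of w, the two in-place updates write into the same memory and together apply exactly dL/dw1 + transpose(dL/dw2).

η  = 0.1f0
w1 = w             # the same array
w2 = transpose(w)  # a lazy view of the same memory

w1 .-= η .* dw1    # subtracts η * dL/dw1 from w
w2 .-= η .* dw2    # subtracts η * transpose(dL/dw2) from w, through the view
# after both updates, w has moved by -η * (dL/dw1 + transpose(dL/dw2))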
I guess it isn't automatic for complex optimizers that track momentum, etc. But it seems like we should then be handling it on the optimizer side, not in the AD. This is where I think explicit params are nicer. What I wrote above is true for implicit params as well (e.g. Example 3 in the main issue) when Params contains x, xt. The trouble with implicit params is that you get all these other cases, issues with hashing, etc., that make dealing with that final equation harder on the optimizer side.
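As a toy illustration of the optimizer-side difficulty (my own sketch, reusing η, w1, w2, dw1, dw2 from above, with an RMSProp-style rule): the state is keyed by object identity, so the tied pair gets two independent second-moment buffers, and because of the g .^ 2 / sqrt nonlinearity the two updates are no longer equivalent to a single update computed from the summed gradient.

βv, ϵ = 0.999f0, 1f-8
moment = IdDict{Any,Any}(w1 => zero(w1), w2 => zero(w2))   # one buffer per object
for (p, g) in ((w1, dw1), (w2, dw2))
    moment[p] = βv .* moment[p] .+ (1 - βv) .* g .^ 2      # nonlinear in g
    p .-= η .* g ./ (sqrt.(moment[p]) .+ ϵ)                # two histories for what is really one parameter
end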