Zygote.jl

Gradient disappears when there is an indexed array in the loss function

Open huangyxi opened this issue 3 years ago • 5 comments

Hi, I want to train some parameters specified by labels in a neural network. I found that when the loss function uses an indexed array, the corresponding gradient comes back as nothing. Here is a minimal working example:

using Flux

x1 = [0., 1.]
x = [x1]
x_real = x[1]
p = [0.5, 0.6]

function loss()
    y = x[1] + p # line {1}
    sum(abs2, y .- 1)
end

ps = Flux.params(x[1], p)
gs = Flux.gradient(loss, ps)

I have checked that the gradient is computed correctly when line {1} is either y = x[1] # {2} or y = x_real + p # {3}. However, when line {1} is y = x[1] + p (neither {2} nor {3} suits our use case), gs.grads looks like this:

IdDict{Any, Any} with 4 entries:
  [0.0, 1.0] => nothing
  [0.5, 0.6] => [-1.0, 1.2]
  :(Main.x)  => [[-1.0, 1.2]]
  :(Main.p)  => [-1.0, 1.2]

which indicates that the gradient with respect to x1 has not been computed.

huangyxi avatar May 23 '22 06:05 huangyxi

Somehow this confuses the implicit parameter handling: you add x1 to Params, but you access it via a different global variable, x, which is not in Params. I guess that's a bug.

Can you avoid it by always storing and accessing the same global variable? (Or, even better, by avoiding implicit parameters entirely.)

julia> using Zygote  # doesn't need Flux

julia> ps = Zygote.Params([x[1], p])  # as above
Params([[0.0, 1.0], [0.5, 0.6]])

julia> gs = Zygote.gradient(loss, ps)
Grads(...)

julia> gs[x1] === nothing
true

julia> x = [x1, x1.+1]
2-element Vector{Vector{Float64}}:
 [0.0, 1.0]
 [1.0, 2.0]
 
julia> ps2 = Zygote.Params([x, p])  # store the outer array in Params
Params([[[0.0, 1.0], [1.0, 2.0]], [0.5, 0.6]])

julia> gs2 = Zygote.gradient(loss, ps2)
Grads(...)

julia> gs2[x]  # fine
2-element Vector{Union{Nothing, Vector{Float64}}}:
 [-1.0, 1.2000000000000002]
 nothing

julia> gs2[x1]
ERROR: KeyError: key [0.0, 1.0] not found

julia> ps3 = Zygote.Params([x, x1, p])  # storing both does not help
Params([[[0.0, 1.0], [1.0, 2.0]], [0.0, 1.0], [0.5, 0.6]])

julia> gs3 = Zygote.gradient(loss, ps3)
Grads(...)

julia> gs3[x]
2-element Vector{Union{Nothing, Vector{Float64}}}:
 [-1.0, 1.2000000000000002]
 nothing

julia> gs3[x1] === nothing
true
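
For comparison, here is a minimal sketch of the first suggestion above (store and access the same global variable). It corresponds to variant {3} from the original report; the name loss_same is made up for this illustration, and the values are the ones from the example.

```julia
using Zygote

x1 = [0.0, 1.0]
p = [0.5, 0.6]

# The loss reads x1 directly, i.e. the very object that is stored in Params,
# instead of reaching it through a second global such as x[1].
loss_same() = sum(abs2, (x1 + p) .- 1)

ps = Zygote.Params([x1, p])
gs = Zygote.gradient(loss_same, ps)

# Both parameters should now have an array-valued gradient.
@show gs[x1]
@show gs[p]
```

This only sidesteps the problem for a fixed, known set of variables; it does not help when the inner arrays live in a collection of variable length.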

mcabbott avatar May 23 '22 13:05 mcabbott

This issue arises because we don't do per-element tracking of implicit gradients for arrays of arrays. This is currently done for tuples, so it may be possible to use a similar pattern for arrays. As Michael mentioned though, I would highly recommend avoiding implicit params if you can. The example above will be both more efficient and less surprising with "explicit" params.

ToucheSir avatar May 23 '22 16:05 ToucheSir

Thanks, Michael Abbott and Brian Chen. Based on the two responses above, I have found two workarounds for this problem:

Store outer params:

x = [x1, x1.+1]
ps = Zygote.Params([x, p])

Store as Tuple:

x = (x1, x1.+1)
ps = Zygote.Params([x[1], p])

A simplification closer to our real situation would be to store the inner arrays as Dict values, since the number of inner arrays is variable and we cannot assign a fixed number of variables in the code. Although this problem has been solved thanks to you, it would be nice if elements of collections other than Tuple could be supported directly as params in the future.

huangyxi avatar May 24 '22 12:05 huangyxi

The alternative we're referring to with explicit params is neither of those, but this:

function loss(x, p)
    y = x[1] + p
    sum(abs2, y .- 1)
end

dx, dp = Flux.gradient(loss, x, p)

Or:

function loss(x1, p)
    y = x1 + p
    sum(abs2, y .- 1)
end

dx_real, dp = Flux.gradient(loss, x_real, p)

In other words, you pass in anything you want to get a gradient for/differentiate with respect to. This avoids all the issues mentioned in this thread and should be slightly more efficient too. You should almost never have to use params or Params with Zygote unless you're relying on a higher-level library which requires them.
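
As a rough sketch of how the explicit style copes with a variable number of inner arrays, the loop below runs a few plain gradient descent steps on the first version of loss above. The helper descend!, the step size η and the step count are assumptions made for this illustration and are not part of Zygote or Flux. Entries of dx that the loss never touches carry no array gradient, so the loop skips them.

```julia
using Zygote

x1 = [0.0, 1.0]
x = [x1, x1 .+ 1]   # the number of inner arrays can vary
p = [0.5, 0.6]

loss(x, p) = sum(abs2, (x[1] + p) .- 1)

# Hypothetical helper: in-place gradient descent on x and p.
function descend!(x, p; η = 0.1, steps = 20)
    for _ in 1:steps
        dx, dp = Zygote.gradient(loss, x, p)
        for i in eachindex(x)
            # Skip inner arrays without an array-valued gradient
            # (the loss above only uses x[1]).
            dx[i] isa AbstractArray || continue
            x[i] .-= η .* dx[i]
        end
        p .-= η .* dp
    end
    return x, p
end

l0 = loss(x, p)   # loss before training
descend!(x, p)
l1 = loss(x, p)   # loss after training
```

Because x[1] and x1 are the same object, the in-place update also changes x1, while x[2] is left as it was.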

ToucheSir avatar May 24 '22 15:05 ToucheSir

Thanks for your explanations and suggestions.

huangyxi avatar May 26 '22 03:05 huangyxi