Zygote.jl

No gradients if we save Flux.params into a variable

Open · MariusDrulea opened this issue 2 years ago · 5 comments

See the following MWE:

```julia
using Flux

model = Dense(2, 2)

xt = rand(Float32, 2, 4) # batch size of 4
yt = rand(Float32, 2, 4)

ps = Flux.params(model)
# reads the parameters through the global `ps`, not through the argument `m`
loss_fun(m, x, y) = 1/2*sum(p->sum(p.^2), ps)

# accesses the fields of the argument `m` directly
loss_fun_explicit(m, x, y) = 1/2*sum(m.weight.^2) + 1/2*sum(m.bias.^2)

# collects the parameters from the argument `m` on every call
loss_fun_slow(m, x, y) = 1/2*sum(p->sum(p.^2), Flux.params(m))

∇m = gradient(m->loss_fun(m, xt, yt), model)
∇m_explicit = gradient(m->loss_fun_explicit(m, xt, yt), model)
∇m_slow = gradient(m->loss_fun_slow(m, xt, yt), model)

@show ∇m
@show ∇m_explicit
@show ∇m_slow
```

The values of the gradients are below. ∇m_explicit and ∇m_slow are equal and correct, but ∇m is `(nothing,)`.

```
∇m = (nothing,)
∇m_explicit = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
∇m_slow = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
```
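A short derivation shows which of these results to expect, assuming `Dense(2, 2)` with its default zero-initialized bias. For the regularization loss used in all three functions, with weight matrix $W$ and bias $b$:

```math
L(W, b) = \tfrac{1}{2}\sum_{i,j} W_{ij}^2 + \tfrac{1}{2}\sum_{i} b_i^2
\quad\Longrightarrow\quad
\frac{\partial L}{\partial W_{ij}} = W_{ij}, \qquad \frac{\partial L}{\partial b_i} = b_i .
```

So the correct weight gradient should equal the model's current weight matrix, and the bias gradient should be zero. That is consistent with `∇m_explicit` and `∇m_slow` (which agree with each other and show `bias = Float32[0.0, 0.0]`). `loss_fun` computes the same value, but because it reads the parameters through the global `ps` rather than through `m`, Zygote sees no dependence on the argument and returns `nothing`.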

MariusDrulea, Dec 27 '22 18:12

Edit after posting the issue: this behavior might be the expected one. I think it must be, since we want to use only explicit loss functions.

MariusDrulea, Dec 27 '22 21:12

It looks like the explicit parameters version is actually correct and the other two are wrong, because they give the same answer when you remove the regularization term. I'm trying to figure out why no gradient is being returned for params, because https://github.com/FluxML/Flux.jl/pull/2118 was explicitly written to allow AD when using explicit params.

ToucheSir, Dec 27 '22 21:12

@ToucheSir I just noticed that the correct way to call the implicit version is `∇m = gradient(() -> loss_fun(model, xt, yt), ps)`: we have to provide a function with no arguments and also pass the `ps` variable. If I do so, I get the same gradient values as for the explicit versions.
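The difference between the two calling conventions can be sketched as follows. This is an illustration, assuming a Flux release (for example 0.13 or 0.14) in which implicit-mode `Flux.params` and `gradient(f, ps)` are still supported; `penalty` is a hypothetical name used only here. The implicit call differentiates a zero-argument closure with respect to the tracked arrays in `ps`, while the explicit call differentiates with respect to the model passed as an argument.

```julia
using Flux  # assumption: a Flux version where implicit-mode Flux.params still works

model = Dense(2, 2)
ps = Flux.params(model)  # the arrays (weight, bias) to track in implicit mode

# Hypothetical zero-argument loss: it reads the parameters from the global `model`.
penalty() = 1/2 * sum(model.weight .^ 2) + 1/2 * sum(model.bias .^ 2)

# Implicit mode: zero-argument closure plus the Params collection.
# The result is a Grads object, indexed by the parameter arrays themselves.
gs = gradient(() -> penalty(), ps)

# Explicit mode: the model is the argument, so the result mirrors its structure.
∇m = gradient(m -> 1/2 * sum(m.weight .^ 2) + 1/2 * sum(m.bias .^ 2), model)[1]

# Both modes should agree; for this loss the weight gradient equals the weight.
@show gs[model.weight] ≈ ∇m.weight
@show gs[model.bias] ≈ ∇m.bias
```

The key point is that in implicit mode `ps` is an argument of `gradient` itself, so accesses to those arrays inside the closure are recorded; in explicit mode only paths from the argument `m` count.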

MariusDrulea, Dec 27 '22 21:12

Yes, but as seen in your edited example, you can also call params on an explicit model. The trouble comes when you iterate over an external (in this case global) variable such as ps or model, because Zygote can't see a path back from it to any of the inputs. The question is whether we can catch such accesses and warn or error as appropriate. My only idea so far is a warning, shown when differentiating params in explicit mode, that links to a docs section outlining what works and what doesn't.

ToucheSir, Dec 27 '22 22:12

The answers presently above look correct to me.

Perhaps a simpler example of what they illustrate is the following. None of these results seem wrong, but the ones mixing explicit arguments and global references are perhaps surprising.

```julia
julia> using Zygote, LinearAlgebra, ForwardDiff

julia> v = [2.0, 3.0];

julia> gradient(x -> dot(x,x), v)
([4.0, 6.0],)

julia> gradient(x -> dot(x,v), v)  # one global reference
([2.0, 3.0],)

julia> ForwardDiff.gradient(x -> dot(x,v), v)  # agrees
2-element Vector{Float64}:
 2.0
 3.0

julia> gradient(x -> dot(v,v), v)  # two global references
(nothing,)

julia> ForwardDiff.gradient(x -> dot(v,v), v)  # agrees
2-element Vector{Float64}:
 0.0
 0.0

julia> gradient(() -> dot(v,v), Params([v]))  # implicit mode
Grads(...)

julia> ans[v]  # same answer as the first, but via global ref.
2-element Vector{Float64}:
 4.0
 6.0
```

mcabbott, Dec 28 '22 14:12