Zygote.jl
No gradients if we save Flux.params into a variable
See the following MWE:
using Flux

model = Dense(2, 2)
xt = rand(Float32, 2, 4) # batch size of 4
yt = rand(Float32, 2, 4)

ps = Flux.params(model) # implicit parameter collection, saved in a (global) variable

loss_fun(m, x, y) = 1/2*sum(p->sum(p.^2), ps)                          # reads the global ps, never touches m
loss_fun_explicit(m, x, y) = 1/2*sum(m.weight.^2) + 1/2*sum(m.bias.^2) # reads the fields of the argument m
loss_fun_slow(m, x, y) = 1/2*sum(p->sum(p.^2), Flux.params(m))         # collects the params of the argument m

∇m = gradient(m->loss_fun(m, xt, yt), model)
∇m_explicit = gradient(m->loss_fun_explicit(m, xt, yt), model)
∇m_slow = gradient(m->loss_fun_slow(m, xt, yt), model)

@show ∇m
@show ∇m_explicit
@show ∇m_slow
The values of the gradients are below. ∇m_explicit and ∇m_slow are equal and correct, but ∇m is nothing.
∇m = (nothing,)
∇m_explicit = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
∇m_slow = ((weight = Float32[0.69311625 -1.0913904; -0.12783962 -0.15561718], bias = Float32[0.0, 0.0], σ = nothing),)
Edit after posting the issue: this behavior might be the expected one. I think it must be, since we want to use only explicit loss functions.
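For reference, a minimal sketch of what such a fully explicit loss could look like, reusing the MWE's model, xt and yt; the loss_explicit name and the mse data term are my additions, not part of the original example. Everything the loss touches arrives through its arguments, with no reference to the global ps:

loss_explicit(m, x, y) = Flux.Losses.mse(m(x), y) + 1/2*sum(p->sum(p.^2), Flux.params(m)) # data term plus L2 penalty, both built from the argument m
gradient(m -> loss_explicit(m, xt, yt), model) # should return a (weight = ..., bias = ..., σ = nothing) named tuple rather than nothing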
It looks like the explicit parameters version is actually correct and the other two are wrong, because they give the same answer when you remove the regularization term. I'm trying to figure out why no gradient is being returned for params, because https://github.com/FluxML/Flux.jl/pull/2118 was explicitly written to allow AD when using explicit params.
@ToucheSir I just noticed that the correct way to call the implicit version is ∇m = gradient(() -> loss_fun(model, xt, yt), ps): we have to provide a zero-argument closure and also pass in the ps variable. If I do so, I get the same gradient values as for the explicit versions.
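For completeness, a hedged sketch of that implicit-mode call and how to read its result, reusing the MWE's variables (the ∇ps name is mine). With a Params argument, gradient returns a Grads object indexed by the parameter arrays themselves:

∇ps = gradient(() -> loss_fun(model, xt, yt), ps) # implicit mode: zero-argument closure plus the Params collection
∇ps[model.weight] # gradient w.r.t. the weight matrix, same values as in ∇m_explicit
∇ps[model.bias]   # gradient w.r.t. the bias vector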
Yes, but as seen in your edited example you can also call params on an explicit model. The trouble comes when you try to iterate over an external (in this case global) variable such as ps or model, because Zygote can't see a path back from those to any of the inputs. The question is whether we can catch such accesses and warn/error as appropriate. My only idea so far is to add a warning which shows up when differentiating params in explicit mode that links to a docs section outlining what works and what doesn't.
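To illustrate that point with the MWE's own names (loss_global is a hypothetical name added here): a loss that reaches the parameters through the global model rather than through the argument m should also come back with no gradient in explicit mode, for the same reason as the ps version.

loss_global(m, x, y) = 1/2*sum(model.weight.^2) # ignores m and reads the global model instead
gradient(m -> loss_global(m, xt, yt), model)    # expected to return (nothing,), just like ∇m above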
The answers presently above look correct to me.
Perhaps a simpler example of what they illustrate is this. None of these seem wrong, but the ones mixing explicit arguments and global references are perhaps surprising.
julia> using Zygote, LinearAlgebra, ForwardDiff
julia> v = [2.0, 3.0];
julia> gradient(x -> dot(x,x), v)
([4.0, 6.0],)
julia> gradient(x -> dot(x,v), v) # one global reference
([2.0, 3.0],)
julia> ForwardDiff.gradient(x -> dot(x,v), v) # agrees
2-element Vector{Float64}:
2.0
3.0
julia> gradient(x -> dot(v,v), v) # two global references
(nothing,)
julia> ForwardDiff.gradient(x -> dot(v,v), v) # agrees
2-element Vector{Float64}:
0.0
0.0
julia> gradient(() -> dot(v,v), Params([v])) # implicit mode
Grads(...)
julia> ans[v] # same answer as the first, but via global ref.
2-element Vector{Float64}:
4.0
6.0
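A hedged follow-up to the last case, spelling out the rule these examples illustrate: in explicit mode a value only receives a gradient if it is reached through the function's arguments, so the two-global-reference example works once both references are passed in explicitly.

julia> gradient((x, y) -> dot(x, y), v, v) # both former globals now arrive as arguments
([2.0, 3.0], [2.0, 3.0])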