Zygote.jl
Reverse over ForwardDiff differentiation: Zygote vs. ReverseDiff
I would like to compute the gradient of a function that contains a ForwardDiff.gradient/ForwardDiff.derivative.
Computing that higher-order gradient with ForwardDiff works, but it will be slow in my case because it's a many-input function.
MWE below.
using ForwardDiff
using ReverseDiff
using Zygote
f(x::Real,y::AbstractVector) = x*sum(y)
df(x,y) = ForwardDiff.derivative(x->f(x,y), x)
# this returns (nothing,)
Zygote.gradient(y->df(0.1,y), rand(5))
# this works; returning vector of ones
ForwardDiff.gradient(y->df(0.1,y), rand(5))
# this works as well
ReverseDiff.gradient(y->df(0.1,y), rand(5))
It would be really great to get this to work, because for my use case I have to deal with complex numbers and it seems like ReverseDiff does not like them that much.
Tagging @mcabbott, as we discussed this on Slack. I would be happy to help in any way I can.
As this was part of the discussion on Slack: I get the same behaviour when using ForwardDiff.gradient inside df:
f(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y)
df(x,y) = ForwardDiff.gradient(x->f(x,y), x)[1]
x = [0.1]
y = rand(5)
# this returns (nothing,)
Zygote.gradient(y->df(x,y), y)
# this works; returning vector of ones
ForwardDiff.gradient(y->df(x,y), y)
# this works as well
ReverseDiff.gradient(y->df(x,y), y)
Thanks for making an issue.
The nothing comes from #968, which sends these calls to Zygote.forwarddiff(f, x) (forward over forward), which does not, and cannot, track variables closed over in its f. IMO the minimal solution here is to remove that and turn some of these cases back into errors, as the silent nothing is too often surprising.
In #769 there was also a proposal to turn Zygote over ForwardDiff around automatically to ForwardDiff over Zygote. Not sure whether this can handle all cases, or be made to work really.
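For the MWE at the top, turning the order around by hand might look like the sketch below. It is only an illustration of the "ForwardDiff over Zygote" idea, not what #769 implements; the helper names grad_y and mixed are made up for this example. Zygote.hessian uses the same forward-over-reverse ordering.

```julia
using ForwardDiff
using Zygote

f(x::Real, y::AbstractVector) = x * sum(y)

# Reverse mode (Zygote) innermost: gradient of f with respect to y at fixed x.
# grad_y is a hypothetical helper name, not part of either package.
grad_y(x, y) = Zygote.gradient(y -> f(x, y), y)[1]

# Forward mode (ForwardDiff) outermost: differentiate that gradient with respect to x.
# Here x enters the Zygote call as a Dual number, so nothing has to be tracked
# through a closure by the reverse pass.
mixed(x, y) = ForwardDiff.derivative(x -> grad_y(x, y), x)

mixed(0.1, rand(5))
```

Since mixed partial derivatives commute for this f, the result should agree with the vector of ones that ForwardDiff.gradient returns in the MWE. Whether this reordering scales to the real use case (complex numbers, many inputs) is not something this sketch shows.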
Alternatively, maybe we can figure out what ReverseDiff is doing. The last example is the surprising one, since ForwardDiff.gradient inside makes an array for its output and then writes into it.
Xref #1189 about this same problem (and others).
Ok I see. ReverseDiff is dropping down to scalar TrackedReal, which it can trace individually through mutation:
julia> ReverseDiff.gradient([1,2,3.0]) do x
           y = zeros(@show(eltype(x)), 3)
           y[1] = x[1] + x[2]^2
           sum(y)
       end
eltype(x) = ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
3-element Vector{Float64}:
1.0
4.0
0.0
Second derivatives:
julia> ForwardDiff.gradient([1, 2]) do x
           ForwardDiff.gradient(x) do y
               (@show(x[1]) * @show(y[2]))^2
           end |> sum
       end
x[1] = Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,1,0)
y[2] = Dual{ForwardDiff.Tag{var"#86#88"{Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}, ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}(Dual{ForwardDiff.Tag{var"#85#87", Int64}}(2,0,1),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(0,0,0),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,0,0))
2-element Vector{Int64}:
8
2
julia> ReverseDiff.gradient([1, 2]) do x
           ForwardDiff.gradient(x) do y
               (@show(x[1]) * @show(y[2]))^2
           end |> sum
       end
x[1] = TrackedReal<7VO>(1, 0, DVN, 1, 890)
y[2] = Dual{ForwardDiff.Tag{var"#90#92"{ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}, ReverseDiff.TrackedReal{Int64, Int64, ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}}}(TrackedReal<Le4>(2, 0, DVN, 2, 890),TrackedReal<7aD>(0, 0, ---, ---),TrackedReal<G74>(1, 0, ---, ---))
2-element Vector{Int64}:
8
2
julia> Zygote.gradient([1, 2]) do x
           ForwardDiff.gradient(x) do y
               (@show(x[1]) * @show(y[2]))^2
           end |> sum
       end
x[1] = 1
y[2] = Dual{ForwardDiff.Tag{var"#94#96"{Vector{Int64}}, ForwardDiff.Dual{Nothing, Int64, 2}}}(Dual{Nothing}(2,0,1),Dual{Nothing}(0,0,0),Dual{Nothing}(1,0,0))
([0.0, 2.0],)
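To make the difference between the three outputs explicit, here is the calculation by hand. The symbol g for the outer function and the constant c are notation introduced only for this note.

```latex
% Inner gradient with respect to y, summed, then evaluated at y = x:
g(x) = \sum_i \frac{\partial}{\partial y_i} (x_1 y_2)^2 \Big|_{y = x}
     = 2 x_1^2 y_2 \Big|_{y = x}
     = 2 x_1^2 x_2

% Full outer gradient at x = (1, 2):
\nabla g(x) = \left( 4 x_1 x_2,\; 2 x_1^2 \right) = (8,\; 2)

% If the closed-over x_1 is treated as a constant c = 1, only the
% dependence through y survives:
g_c(x) = 2 c^2 x_2, \qquad \nabla g_c(x) = \left( 0,\; 2 c^2 \right) = (0,\; 2)
```

So ForwardDiff and ReverseDiff return the full gradient (8, 2), while the Zygote result (0, 2) is what one gets when the contribution of the closed-over x[1] is dropped, consistent with the explanation above that closed-over variables are not tracked.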
Is there any chance that this will make it into Zygote?
Zygote uses a completely different mechanism for its AD, so I doubt ReverseDiff's approach would be directly applicable (if at all).