Zygote.jl

Reverse over ForwardDiff differentiation: Zygote vs. ReverseDiff

Open nmheim opened this issue 3 years ago • 5 comments

I would like to compute the gradient of a function that contains a ForwardDiff.gradient/ForwardDiff.derivative call. Computing that higher-order gradient with ForwardDiff works, but it will be slow in my case because the function has many inputs. MWE below.

using ForwardDiff
using ReverseDiff
using Zygote

f(x::Real,y::AbstractVector) = x*sum(y)

df(x,y) = ForwardDiff.derivative(x->f(x,y), x)

# this returns (nothing,)
Zygote.gradient(y->df(0.1,y), rand(5))

# this works, returning a vector of ones
ForwardDiff.gradient(y->df(0.1,y), rand(5))

# this works as well
ReverseDiff.gradient(y->df(0.1,y), rand(5))

It would be really great to get this to work, because my use case involves complex numbers and ReverseDiff does not seem to handle them well.

Tagging @mcabbott, as we have discussed this on Slack. I would be happy to help in any way I can.

nmheim, May 06 '22 23:05

As this was part of the discussion on Slack: I get the same behaviour when using ForwardDiff.gradient inside the function instead of ForwardDiff.derivative:

f(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y)

df(x,y) = ForwardDiff.gradient(x->f(x,y), x)[1]

x = [0.1]
y = rand(5)

# this returns (nothing,)
Zygote.gradient(y->df(x,y), y)

# this works, returning a vector of ones
ForwardDiff.gradient(y->df(x,y), y)

# this works as well
ReverseDiff.gradient(y->df(x,y), y)

nmheim, May 06 '22 23:05

Thanks for making an issue.

The nothing comes from #968, which routes these calls to Zygote.forwarddiff(f, x), i.e. forward over forward, and that does not (and cannot) track variables closed over in its f. IMO the minimal fix here is to remove that routing and turn some of these cases back into errors, as the silent nothing is too often surprising.
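A ForwardDiff-only sketch of what losing a closed-over variable costs, using the same nested function as the second-derivative REPL session later in this thread. inner_tracked lets the outer dual numbers reach the closure through the captured x, while inner_frozen (a hypothetical helper, for illustration only) replaces the captured x[1] by a plain constant, mimicking a rule that cannot see closed-over variables:

```julia
using ForwardDiff

# Inner gradient closes over x, so outer duals flow in through x[1]
# as well as through the evaluation point.
inner_tracked(x) = sum(ForwardDiff.gradient(y -> (x[1] * y[2])^2, x))

# Hypothetical variant: the closed-over x[1] is replaced by a constant,
# so the outer derivative only sees x through the evaluation point.
const c = 1.0
inner_frozen(x) = sum(ForwardDiff.gradient(y -> (c * y[2])^2, x))

x0 = [1.0, 2.0]
full   = ForwardDiff.gradient(inner_tracked, x0)  # [8.0, 2.0]
frozen = ForwardDiff.gradient(inner_frozen, x0)   # [0.0, 2.0]
```

The frozen result has the same [0.0, 2.0] pattern that Zygote returns for this function in the REPL session further down.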

In #769 there was also a proposal to automatically turn Zygote over ForwardDiff around into ForwardDiff over Zygote. I'm not sure whether this can handle all cases, or whether it can really be made to work.
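A hand-written sketch of that reversal for the first MWE in this issue, assuming f is smooth so that mixed partials commute (the gradient in y of ∂f/∂x equals the derivative in x of ∇_y f). Zygote takes the reverse-mode gradient over the many inputs y, and ForwardDiff differentiates that result with respect to the single scalar x. The name mixed is a hypothetical helper; this is not the automatic rewrite proposed in #769:

```julia
using ForwardDiff, Zygote

f(x::Real, y::AbstractVector) = x * sum(y)

# ForwardDiff (outer, scalar x) over Zygote (inner, vector y).
# Because mixed partials commute, this equals the gradient in y of ∂f/∂x.
mixed(x, y) = ForwardDiff.derivative(x -> Zygote.gradient(y -> f(x, y), y)[1], x)

y = rand(5)
g = mixed(0.1, y)  # expected to match ForwardDiff.gradient(y -> df(0.1, y), y), i.e. ones(5)
```

With a single scalar x, ForwardDiff.derivative carries one partial, so this needs only one (dual-valued) Zygote pass over the many-input function.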

Alternatively, maybe we can figure out what ReverseDiff is doing. The last example is the surprising one: the inner ForwardDiff.gradient allocates an array for its output and then writes into it, yet ReverseDiff still gets the right answer.

Xref #1189 about this same problem (and others).

mcabbott, May 07 '22 00:05

OK, I see. ReverseDiff drops down to scalar TrackedReals, which it can trace individually through mutation:

julia> ReverseDiff.gradient([1,2,3.0]) do x
         y = zeros(@show(eltype(x)), 3)
         y[1] = x[1] + x[2]^2
         sum(y)
       end
eltype(x) = ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
3-element Vector{Float64}:
 1.0
 4.0
 0.0
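For contrast, a sketch of the same mutating function under Zygote. Zygote does not trace scalar writes into arrays, so we expect it to throw (its message is along the lines of "Mutating arrays is not supported") rather than return a gradient. The same kind of write happens inside ForwardDiff.gradient. The function name h is only illustrative:

```julia
using Zygote

# Same function as the ReverseDiff example: allocate, write one entry, sum.
function h(x)
    y = zeros(eltype(x), 3)
    y[1] = x[1] + x[2]^2   # in-place write into an array
    return sum(y)
end

# Zygote is expected to reject the setindex! call instead of tracing it.
err = try
    Zygote.gradient(h, [1, 2, 3.0])
    nothing
catch e
    e
end
println(err)
```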

Second derivatives:

julia> ForwardDiff.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,1,0)
y[2] = Dual{ForwardDiff.Tag{var"#86#88"{Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}, ForwardDiff.Dual{ForwardDiff.Tag{var"#85#87", Int64}, Int64, 2}}}(Dual{ForwardDiff.Tag{var"#85#87", Int64}}(2,0,1),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(0,0,0),Dual{ForwardDiff.Tag{var"#85#87", Int64}}(1,0,0))
2-element Vector{Int64}:
 8
 2

julia> ReverseDiff.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = TrackedReal<7VO>(1, 0, DVN, 1, 890)
y[2] = Dual{ForwardDiff.Tag{var"#90#92"{ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}, ReverseDiff.TrackedReal{Int64, Int64, ReverseDiff.TrackedArray{Int64, Int64, 1, Vector{Int64}, Vector{Int64}}}}}(TrackedReal<Le4>(2, 0, DVN, 2, 890),TrackedReal<7aD>(0, 0, ---, ---),TrackedReal<G74>(1, 0, ---, ---))
2-element Vector{Int64}:
 8
 2

julia> Zygote.gradient([1, 2]) do x
         ForwardDiff.gradient(x) do y
           (@show(x[1]) * @show(y[2]))^2
         end |> sum
       end
x[1] = 1
y[2] = Dual{ForwardDiff.Tag{var"#94#96"{Vector{Int64}}, ForwardDiff.Dual{Nothing, Int64, 2}}}(Dual{Nothing}(2,0,1),Dual{Nothing}(0,0,0),Dual{Nothing}(1,0,0))
([0.0, 2.0],)
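Working this example by hand shows where Zygote's 0.0 comes from. Write g for the function inside the inner gradient and s for the outer function, where the inner gradient is taken in y and then evaluated at y = x:

```latex
g(x, y) = (x_1 y_2)^2, \qquad
\nabla_y g(x, y) = \begin{pmatrix} 0 \\ 2 x_1^2 y_2 \end{pmatrix}

s(x) = \sum_i \big[\nabla_y g(x, y)\big]_i \Big|_{y = x} = 2 x_1^2 x_2

\nabla s(x) = \begin{pmatrix} 4 x_1 x_2 \\ 2 x_1^2 \end{pmatrix}
            = \begin{pmatrix} 8 \\ 2 \end{pmatrix} \quad \text{at } x = (1, 2)
```

If the closed-over x_1 is instead treated as a constant c, then s(x) = 2c^2 x_2 and ∇s = (0, 2c^2) = (0, 2) for c = 1. That is exactly the ([0.0, 2.0],) Zygote returns: only the dependence through the evaluation point y = x survives.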

mcabbott, May 10 '22 02:05

Is there any chance that this will make it into Zygote?

nmheim, May 10 '22 08:05

Zygote uses a completely different mechanism for its AD, so I doubt ReverseDiff's approach would be directly applicable (if at all).

ToucheSir, May 11 '22 02:05