Zygote.jl

Target ForwardDiff's public API?

Open ChrisRackauckas opened this issue 5 years ago • 2 comments

Instead of doing the harder thing of making Duals work in general, why not target the higher-level API? I played around with it a bit and didn't quite get it working, but someone might want to take it from here:

using ForwardDiff, Zygote, ZygoteRules, FiniteDiff, Test, Adapt

ZygoteRules.@adjoint function ForwardDiff.derivative(f,x)
    der = ForwardDiff.derivative(f,x)
    function derivative_adjoint(Δ)
        function _f(y)
            out,back = Zygote.pullback(f,y)
            back(Δ)[1]
        end
        (nothing,ForwardDiff.derivative(_f,x))
    end
    der, derivative_adjoint
end

ZygoteRules.@adjoint function ForwardDiff.gradient(f,x)
    grad = ForwardDiff.gradient(f,x)
    function gradient_adjoint(Δ)
        function _f(y)
            out,back = Zygote.pullback(f,y)
            back(Δ)[1]
        end
        (nothing,ForwardDiff.gradient(_f,x))
    end
    grad, gradient_adjoint
end

ZygoteRules.@adjoint function ForwardDiff.jacobian(f,x)
    jac = ForwardDiff.jacobian(f,x)
    function jacobian_adjoint(Δ)
        function _f(y)
            out,back = Zygote.pullback(f,y)
            vec(back(Δ)[1])
        end
        (nothing,ForwardDiff.jacobian(_f,x))
    end
    jac, jacobian_adjoint
end

f(x) = 2x^2 + x
g(x) = ForwardDiff.derivative(f,x)
out,back = Zygote.pullback(g,2.0)
stakehouse = back(1)[1]
@test typeof(stakehouse) <: Float64
@test stakehouse[1] == ForwardDiff.derivative(g,2.0)

f(x) = [2x[1]^2 + x[1],x[2]^2 * x[1]]
g(x) = sum(ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse == ForwardDiff.gradient(g,[2.0,3.2])

g(x) = prod(ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse == ForwardDiff.gradient(g,[2.0,3.2])

g(x) = sum(abs2,ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g,[2.0,3.2])
stakehouse = back(1.0)[1]
@test typeof(stakehouse) <: Vector
@test size(stakehouse) == (2,)
@test stakehouse == ForwardDiff.gradient(g,[2.0,3.2])
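
For the scalar case, the forward-over-reverse idea behind the adjoints above can be checked by hand without defining any rule. The following is only a minimal sketch: the helper names `vjp` and `adjoint_of_derivative` are made up for illustration and are not part of Zygote or ForwardDiff. It assumes Zygote can differentiate `f` when the input is a ForwardDiff dual number, which should hold for a simple polynomial like this one.

```julia
using ForwardDiff, Zygote

f(x) = 2x^2 + x   # same scalar test function as in the tests above

# Reverse pass: for a scalar f, pulling back the cotangent Δ gives Δ * f'(y).
# (`vjp` is a hypothetical helper name used only in this sketch.)
vjp(y, Δ) = Zygote.pullback(f, y)[2](Δ)[1]

# Forward pass over the reverse pass: d/dy [Δ * f'(y)] = Δ * f''(y).
# This is the quantity an adjoint of x -> ForwardDiff.derivative(f, x)
# has to return for the argument x.
adjoint_of_derivative(x, Δ) = ForwardDiff.derivative(y -> vjp(y, Δ), x)

# f''(x) = 4 everywhere for this f, so the result scales linearly with Δ.
println(adjoint_of_derivative(2.0, 1.0))
```

The vector cases (`gradient`, `jacobian`) need more care, because there the incoming Δ has the shape of the gradient or Jacobian rather than the shape of the output of `f`.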

ChrisRackauckas, Aug 16 '20 21:08

This came up when a user tried to use ForwardDiff.derivative in the loss function with GalacticOptim; Zygote fails in that case. After further discussion on Slack, and with @mcabbott's help, we got it working (with some modifications):

julia> using Zygote, ForwardDiff
julia> g(t) = t .* ones(size(x0)...)
g (generic function with 1 method)
julia> dg(t) = ForwardDiff.jacobian(g,[t])
dg (generic function with 1 method)
julia> f(x,p) =  sum(abs2, sum(x + dg(p[1])))
f (generic function with 1 method)
julia> x0 = zeros(10)
10-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
julia> p  = [1.]
1-element Vector{Float64}:
 1.0
julia> Zygote.gradient(x -> f(x, p), x0)
ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#407#408")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:61
  [3] (::Zygote.var"#2269#back#409"{Zygote.var"#407#408"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] Pullback
    @ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:119 [inlined]
  [8] (::typeof(∂(extract_jacobian!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:150 [inlined]
 [10] (::typeof(∂(vector_mode_jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:21 [inlined]
 [12] (::typeof(∂(jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [13] Pullback (repeats 2 times)
    @ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:19 [inlined]
 [14] (::typeof(∂(jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [15] Pullback
    @ ./REPL[5]:1 [inlined]
 [16] (::typeof(∂(dg)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] Pullback
    @ ./REPL[6]:1 [inlined]
 [18] (::typeof(∂(f)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[10]:1 [inlined]
 [20] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(#3))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [22] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [23] top-level scope
    @ REPL[10]:1
julia> 
julia> @eval Zygote begin
           @adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
           @adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
       end
julia> Zygote.gradient(x -> f(x, p), x0)
(10×1 Fill{Float64}: entries equal to 20.0,)
julia> ForwardDiff.gradient(x -> f(x,p), x0)
10-element Vector{Float64}:
 20.0
 20.0
 20.0
 20.0
 20.0
 20.0
 20.0
 20.0
 20.0
 20.0

Vaibhavdixit02, May 09 '21 15:05

The example above should now just work on Zygote#master.

However, this caveat deserves a link (https://github.com/FluxML/Zygote.jl/issues/953#issuecomment-841882071): implicit parameters inside the function you pass to ForwardDiff.xxx will be ignored. I think that's also true of the (forward-over-reverse) implementations proposed above.
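
A rough way to see the caveat is to compare against nested forward mode. This sketch assumes a Zygote version that already ships rules for `ForwardDiff.gradient` (older versions may instead throw the mutation error shown earlier), and it uses a captured variable `w` as a stand-in for an implicit parameter. The names `loss`, `reference` and `through_zygote` are illustrative only.

```julia
using ForwardDiff, Zygote

x = [1.0, 2.0]
W = [3.0, 4.0]

# The inner closure v -> sum(w .* v.^2) captures w. Analytically the inner
# gradient is 2 .* w .* v, so loss(w) = 2 * sum(w .* x) and its gradient
# with respect to w is 2 .* x.
loss(w) = sum(ForwardDiff.gradient(v -> sum(w .* v.^2), x))

# Nested forward mode tracks w through the closure.
reference = ForwardDiff.gradient(loss, W)

# Reverse mode through the ForwardDiff call may drop the dependence on w,
# since only the explicit argument of the inner function is differentiated.
through_zygote = Zygote.gradient(loss, W)[1]

println("reference:      ", reference)
println("through Zygote: ", through_zygote)
```

If the two printed values disagree, the dependence on `w` was lost inside the ForwardDiff call, which is the situation the linked comment warns about.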

mcabbott, May 17 '21 15:05