Zygote.jl
Target ForwardDiff's public API?
Instead of doing the harder thing of making `Dual`s work in general, why not target ForwardDiff's higher-level API (`derivative`, `gradient`, `jacobian`)? I played around with this a bit and didn't quite get it working, but someone might want to take it from here:
using ForwardDiff, Zygote, ZygoteRules, FiniteDiff, Test, Adapt
# Zygote rule for `ForwardDiff.derivative(f, x)` (scalar `x`).
#
# The primal value is ForwardDiff's own result. Given the output cotangent `Δ`
# (a scalar, or an array when `f` is array-valued), the cotangent for `x` is
#     x̄ = d/dx ⟨Δ, f′(x)⟩,
# computed by nesting `ForwardDiff.derivative`. Zygote therefore never traces
# ForwardDiff's (mutating) internals, and it never has to differentiate `f` on `Dual`
# inputs. `f` itself receives no cotangent (`nothing`), so parameters captured by
# `f` are not differentiated.
ZygoteRules.@adjoint function ForwardDiff.derivative(f, x)
    der = ForwardDiff.derivative(f, x)
    function derivative_adjoint(Δ)
        # No incoming sensitivity: nothing to propagate.
        Δ === nothing && return (nothing, nothing)
        # ⟨Δ, f′(y)⟩ as a scalar function of y; `.*` + `sum` covers scalar and array outputs.
        pairing(y) = sum(Δ .* ForwardDiff.derivative(f, y))
        return (nothing, ForwardDiff.derivative(pairing, x))
    end
    return der, derivative_adjoint
end
# Zygote rule for `ForwardDiff.gradient(f, x)` (scalar-valued `f`).
#
# The cotangent `Δ` has the gradient's shape, which is `x`'s shape. The cotangent
# for `x` is
#     x̄ = ∇ₓ ⟨Δ, ∇f(x)⟩ = H(x)ᵀ Δ        (H = Hessian of f).
# It is obtained by taking the gradient of the *scalar* pairing ⟨Δ, ∇f(y)⟩.
# Feeding the array `Δ` into the pullback of the scalar-valued `f` instead would be
# a shape error. `f` itself receives no cotangent (`nothing`), so parameters captured
# by `f` are not differentiated.
ZygoteRules.@adjoint function ForwardDiff.gradient(f, x)
    grad = ForwardDiff.gradient(f, x)
    function gradient_adjoint(Δ)
        # No incoming sensitivity: nothing to propagate.
        Δ === nothing && return (nothing, nothing)
        # Scalar pairing, so ForwardDiff.gradient returns something shaped like `x`.
        pairing(y) = sum(Δ .* ForwardDiff.gradient(f, y))
        return (nothing, ForwardDiff.gradient(pairing, x))
    end
    return grad, gradient_adjoint
end
# Zygote rule for `ForwardDiff.jacobian(f, x)`.
#
# The cotangent `Δ` has the Jacobian's shape (length(f(x)) × length(x)). The
# cotangent for `x` is
#     x̄ = ∇ₓ Σᵢⱼ Δᵢⱼ Jᵢⱼ(x),
# i.e. the gradient of the scalar contraction of `Δ` with the Jacobian. It cannot be
# obtained by passing `Δ` to the pullback of `f`, whose output cotangent is only
# length(f(x)) long. `f` itself receives no cotangent (`nothing`), so parameters
# captured by `f` are not differentiated.
ZygoteRules.@adjoint function ForwardDiff.jacobian(f, x)
    jac = ForwardDiff.jacobian(f, x)
    function jacobian_adjoint(Δ)
        # No incoming sensitivity: nothing to propagate.
        Δ === nothing && return (nothing, nothing)
        # Scalar contraction ⟨Δ, J(y)⟩; its gradient has `x`'s shape.
        pairing(y) = sum(Δ .* ForwardDiff.jacobian(f, y))
        return (nothing, ForwardDiff.gradient(pairing, x))
    end
    return jac, jacobian_adjoint
end
# Check the rules above: the gradient Zygote computes through ForwardDiff's API must
# match ForwardDiff applied to the same composite function. Each case gets its own
# @testset scope, so the local `f`/`g` definitions do not overwrite one another.
# Results come from different AD paths, so floats are compared with `≈`, not `==`.
@testset "Zygote over ForwardDiff.derivative" begin
    f(x) = 2x^2 + x
    g(x) = ForwardDiff.derivative(f, x)
    _, back = Zygote.pullback(g, 2.0)
    dx = back(1)[1]
    @test dx isa Float64
    @test dx ≈ ForwardDiff.derivative(g, 2.0)
end

@testset "Zygote over ForwardDiff.jacobian: $label" for (label, reduction) in (
    ("sum", sum),
    ("prod", prod),
    ("sum(abs2)", J -> sum(abs2, J)),
)
    f(x) = [2x[1]^2 + x[1], x[2]^2 * x[1]]
    g(x) = reduction(ForwardDiff.jacobian(f, x))
    x0 = [2.0, 3.2]
    _, back = Zygote.pullback(g, x0)
    dx = back(1.0)[1]
    @test dx isa Vector
    @test size(dx) == (2,)
    @test dx ≈ ForwardDiff.gradient(g, x0)
end
This came up when a user tried to use `ForwardDiff.derivative` in a loss function with GalacticOptim; Zygote fails in that case. After further discussion on Slack, and with @mcabbott's help, we got it working (with some modifications):
julia> using Zygote, ForwardDiff
julia> g(t) = t .* ones(size(x0)...)
g (generic function with 1 method)
julia> dg(t) = ForwardDiff.jacobian(g,[t])
dg (generic function with 1 method)
julia> f(x,p) = sum(abs2, sum(x + dg(p[1])))
f (generic function with 1 method)
julia> x0 = zeros(10)
10-element Vector{Float64}:
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
julia> p = [1.]
1-element Vector{Float64}:
1.0
julia> Zygote.gradient(x -> f(x, p), x0)
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#407#408")(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:61
[3] (::Zygote.var"#2269#back#409"{Zygote.var"#407#408"})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ./broadcast.jl:894 [inlined]
[5] Pullback
@ ./broadcast.jl:891 [inlined]
[6] Pullback
@ ./broadcast.jl:887 [inlined]
[7] Pullback
@ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:119 [inlined]
[8] (::typeof(∂(extract_jacobian!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:150 [inlined]
[10] (::typeof(∂(vector_mode_jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:21 [inlined]
[12] (::typeof(∂(jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[13] Pullback (repeats 2 times)
@ ~/.julia/packages/ForwardDiff/QOqCN/src/jacobian.jl:19 [inlined]
[14] (::typeof(∂(jacobian)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[15] Pullback
@ ./REPL[5]:1 [inlined]
[16] (::typeof(∂(dg)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] Pullback
@ ./REPL[6]:1 [inlined]
[18] (::typeof(∂(f)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[19] Pullback
@ ./REPL[10]:1 [inlined]
[20] (::typeof(∂(#3)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[21] (::Zygote.var"#41#42"{typeof(∂(#3))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
[22] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
[23] top-level scope
@ REPL[10]:1
julia>
julia> @eval Zygote begin
@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
end
julia> Zygote.gradient(x -> f(x, p), x0)
(10×1 Fill{Float64}: entries equal to 20.0,)
julia> ForwardDiff.gradient(x -> f(x,p), x0)
10-element Vector{Float64}:
20.0
20.0
20.0
20.0
20.0
20.0
20.0
20.0
20.0
20.0
The example above should now just work on Zygote#master.
However, this deserves a link: <https://github.com/FluxML/Zygote.jl/issues/953#issuecomment-841882071>. Implicit parameters inside the function you pass to `ForwardDiff.xxx` will be ignored. I think that's also true of the forward-over-reverse implementations proposed above.