ReverseDiff.jl
@grad_from_chainrules macro fails when using multi-output functions
Dear team,
first: Thanks for developing this nice package :-)
I think the macro @grad_from_chainrules has a bug when it is applied to multi-output functions (for example, a function that returns a tuple of two vectors). Note that gradient/Jacobian computation is not part of the current GitHub tests: they only evaluate the rrules directly and never build a gradient or Jacobian with ReverseDiff from the corresponding rrule. For single-output functions, however, the macro works fine together with ReverseDiff.gradient.
See the following MWE:
using ForwardDiff, Zygote, ReverseDiff, ChainRulesCore
# SINGLE OUTPUT FUNCTION
f(x) = sum(4x .+ 1)
function ChainRulesCore.rrule(::typeof(f), x)
    r = f(x)
    function back(d)
        return ChainRulesCore.NoTangent(), fill(3, size(x))
    end
    return r, back
end
ReverseDiff.@grad_from_chainrules f(x::AbstractVector{<:ReverseDiff.TrackedReal})
seed = rand(3)
# Everything is OK: ForwardDiff computes the correct derivatives (no frule is defined),
# while ReverseDiff and Zygote use the new rrule, as expected.
ForwardDiff.gradient(f, seed)
Zygote.gradient(f, seed)[1]
ReverseDiff.gradient(f, seed)
# MULTI OUTPUT FUNCTION
f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y)
function ChainRulesCore.rrule(::typeof(f_multi), x, y)
    r = f_multi(x, y)
    function back(d)
        y1, y2 = d
        return ChainRulesCore.NoTangent(), fill(2, size(x)), fill(3, size(y))
    end
    return r, back
end
ReverseDiff.@grad_from_chainrules f_multi(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})
# ForwardDiff computes the correct derivatives (no frule is defined),
# Zygote uses the new rrule as expected, but ReverseDiff fails!
ForwardDiff.jacobian(x -> f_multi(x, ones(3))[1], seed)
Zygote.jacobian(x -> f_multi(x, ones(3))[1], seed)[1]
ReverseDiff.jacobian(x -> f_multi(x, ones(3))[1], seed) # this errors!
Tested with Julia 1.8.5; all packages used are up to date.
Thanks in advance & best regards!
Forgot to post the error message:
ERROR: MethodError: no method matching track(::Tuple{Vector{Float64}, Vector{Float64}}, ::Vector{ReverseDiff.AbstractInstruction})
Closest candidates are:
track(::AbstractArray, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:469
track(::Real, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:467
track(::typeof(vcat), ::Union{Number, AbstractVecOrMat}...) at ...\ReverseDiff.jl\src\macros.jl:190
...
Stacktrace:
[1] track(#unused#::typeof(f_multi), x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
@ Main ...\ReverseDiff.jl\src\macros.jl:329
[2] f_multi(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
@ Main ...\ReverseDiff.jl\src\macros.jl:324
[3] (::var"#17#18")(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
@ Main ...\MWE_multi_reversediff.jl:44
[4] ReverseDiff.JacobianTape(f::var"#17#18", input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing})
@ ReverseDiff ...\ReverseDiff.jl\src\api\tape.jl:229
[5] jacobian(f::Function, input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing}) (repeats 2 times)
@ ReverseDiff ...\src\api\jacobians.jl:23
[6] top-level scope
@ ...\MWE_multi_reversediff.jl:44
+1 --- I've run into the same problem and my MWE is almost identical to the one above.
For me this is a huge problem, because I am hoping to use ReverseDiff over Zygote to get second derivatives. But when you implement a pullback of a pullback, you will typically have multiple outputs to take care of.
If anybody can suggest how to fix this or work around it, I'd be very grateful.
CC @tjjarvinen
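One possible workaround, offered as an untested sketch rather than a fix for the macro: the error message shows that track only has methods for Real and AbstractArray outputs, so the tuple could be packed into a single vector before it reaches ReverseDiff. The wrapper name f_packed below is hypothetical (it is not part of ReverseDiff or ChainRulesCore), and unlike the MWE the pullback here returns the true derivatives of f_multi. The sketch assumes x and y have the same length.

```julia
using ReverseDiff, ChainRulesCore

f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y)

# Hypothetical wrapper: concatenate both outputs into one vector, so the
# primal result is an AbstractArray instead of a Tuple.
f_packed(x, y) = vcat(f_multi(x, y)...)

function ChainRulesCore.rrule(::typeof(f_packed), x, y)
    n = length(x)
    r = f_packed(x, y)
    function back(d)
        # Split the incoming cotangent into the parts belonging to each output.
        d1 = d[1:n]
        d2 = d[n+1:end]
        # out1 = 4x .+ 1 and out2 = 3x .+ 1 .+ y, hence:
        x̄ = 4 .* d1 .+ 3 .* d2
        ȳ = d2
        return ChainRulesCore.NoTangent(), x̄, ȳ
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f_packed(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})

seed = rand(3)
# Jacobian of the first output (entries 1:3 of the packed vector) w.r.t. x.
J = ReverseDiff.jacobian(x -> f_packed(x, ones(3))[1:3], seed)
```

The individual outputs are recovered by slicing the packed vector inside the differentiated closure. Whether this is practical for a pullback of a pullback depends on whether all outputs can be flattened into arrays of one element type.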