ReverseDiff.jl
The number-of-arguments limitation of the @forward macro.
The docstring of ReverseDiff.@forward says: "Currently, only length(args) <= 2 is supported."
However, for a simple function, this macro seems to work even with more than two arguments:
julia> import ReverseDiff
julia> ReverseDiff.@forward f(x, y, z, w) = x + 2y + 3z + 4w
ReverseDiff.ForwardOptimize{##hidden_f}(#hidden_f)
julia> f(1.0, 2.0, 3.0, 4.0)
30.0
julia> ∇f = ReverseDiff.compile_gradient(x -> f(x[1], x[2], x[3], x[4]), zeros(4))
(::#301) (generic function with 1 method)
julia> ∇f(zeros(4), ones(4))
4-element Array{Float64,1}:
1.0
2.0
3.0
4.0
Can I always expect this to work, or is there a pitfall in this case?
I believe this is actually hitting the n-ary fallback, which does indeed still work, but isn't doing all of the differentiation in forward mode:
julia> import ReverseDiff
julia> ReverseDiff.@forward f(x, y, z, w) = x + 2y + 3z + 4w
ReverseDiff.ForwardOptimize{##hidden_f}(#hidden_f)
julia> cfg = ReverseDiff.GradientConfig(rand(4))
ReverseDiff.GradientConfig
# record operations to cfg.tape
julia> f(cfg.input...)
TrackedReal<5j3>(1.30302042715e-313, 0.0, FT7, ---)
# if `f` were differentiated completely in forward mode,
# there would be only a single instruction in the tape
julia> cfg.tape
6-element RawTape:
1 => ScalarInstruction(*):
input: (2,
TrackedReal<8dC>(2.1641615666e-314, 0.0, FT7, 2, Hot))
output: TrackedReal<9CN>(4.328323133e-314, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(2.0,2.0))
2 => ScalarInstruction(*):
input: (3,
TrackedReal<GXL>(2.167760388e-314, 0.0, FT7, 3, Hot))
output: TrackedReal<2e3>(6.5032811646e-314, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(3.0,3.0))
3 => ScalarInstruction(*):
input: (4,
TrackedReal<Kzu>(3.29036e-318, 0.0, FT7, 4, Hot))
output: TrackedReal<2UN>(1.3161435e-317, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(4.0,4.0))
4 => ScalarInstruction(+):
input: (TrackedReal<4hC>(2.1972838303e-314, 0.0, FT7, 1, Hot),
TrackedReal<9CN>(4.328323133e-314, 0.0, FT7, ---))
output: TrackedReal<KtE>(6.5256069634e-314, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))
5 => ScalarInstruction(+):
input: (TrackedReal<KtE>(6.5256069634e-314, 0.0, FT7, ---),
TrackedReal<2e3>(6.5032811646e-314, 0.0, FT7, ---))
output: TrackedReal<40f>(1.3028888128e-313, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))
6 => ScalarInstruction(+):
input: (TrackedReal<40f>(1.3028888128e-313, 0.0, FT7, ---),
TrackedReal<2UN>(1.3161435e-317, 0.0, FT7, ---))
output: TrackedReal<5j3>(1.30302042715e-313, 0.0, FT7, ---)
cache: Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))
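For contrast, here is a minimal sketch of the binary case (`g` and `cfg2` are illustrative names; exact output omitted). With two arguments, @forward differentiates the whole function in forward mode, so the tape should end up holding a single instruction:
julia> ReverseDiff.@forward g(x, y) = x + 2y
julia> cfg2 = ReverseDiff.GradientConfig(rand(2))
julia> g(cfg2.input...)
julia> cfg2.tape  # expected: 1-element RawTape with a single ScalarInstruction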
This behavior is only a stop-gap until @forward supports an arbitrary number of arguments, which it eventually should.
@jrevels Thank you for the clarification. So a function with more than two arguments is compiled using a mixture of forward- and reverse-mode differentiation, right? I don't know why this limitation exists, but I'm happy to know it will eventually be removed.
So a function with more than two arguments is compiled using a mixture of forward- and reverse-mode differentiation, right?
Yup.
I don't know why this limitation exists, but I'm happy to know it will eventually be removed.
It's simply more difficult to write an n-ary version, so I started with the unary and binary cases.
I'm actually going to re-open this issue to track progress on n-ary @forward support.
If it helps, I wrote a module that makes the @forward macro usable for functions with an arbitrary number of scalar arguments (up to 10 in the example here). Is there a good reason not to do this, or not to do it this way? Code optimization is not my biggest strength, so I would be glad if anyone has suggestions on how to improve it!
@everywhere module enrichReverseDiff
using ReverseDiff,StaticArrays,ForwardDiff,DiffResults
import StaticArrays:SVector
import ReverseDiff:scalar_forward_exec!,ScalarInstruction,istracked,deriv,unseed!,increment_deriv!,pull_value!,value,RefValue,record!,unary_scalar_forward_exec!,binary_scalar_forward_exec!,tape,track,valtype,ForwardOptimize,TrackedReal,scalar_reverse_exec!,GradientTape,_GradientTape,track!,GradientConfig
# argument names and type-parameter names for up to 10 arguments
names = (:(aa,bb,cc,dd,ee,ff,gg,hh,ii,jj).args...)
types = (:(AA,BB,CC,DD,EE,FF,GG,HH,II,JJ).args...)
for k in 3:10
n = names[1:k]
N = parse(string("(",join(string.(n),","),")"))
t = types[1:k]
T = parse(string("(",join(string.(t),","),")"))
arg = map((x,y)->:($x::$y),n,t)
Q = quote
@generated function (self::ForwardOptimize{F})($(arg...)) where {F,$(t...)}
# Identify and Count the number of args that are TrackedReal
ex = Expr(:meta,:inline)
nTracked = 0
idTracked = Array{Bool,1}($k)
for j in 1:$k
if $(T)[j]<:TrackedReal
nTracked+=1
idTracked[j]=true
else
idTracked[j]=false
end
end
n = $n
if nTracked==0
# fallback
ex = quote
$ex
self.f($(n...))
end
else
k = $k
v2e = (map(x->:(typeof($x.value)),$n[idTracked])...,
map(x->:(typeof($x)),$n[.!idTracked])...,
(x->:(typeof($x.deriv)))($n[findfirst(idTracked)]))
kk = sum(idTracked)
if nTracked==1
V = :(value($(n[idTracked][1])))
ex = quote
$ex
T=promote_type($(v2e...))
result = DiffResults.DiffResult(zero(T),zero(T))
nArray = Array{Any,1}(vcat($(n...)))
$((nArray,result,V,f)->
ForwardDiff.derivative!(result, x->f(setindex!(nArray,x,idTracked)...), V))(nArray,result,$V,self.f)
tp = tape($(n[idTracked][1]))
out = track(DiffResults.value(result), $(v2e[end]), tp)
dt = DiffResults.derivative(result)
cache = RefValue(SVector{$k,typeof(dt)}(setindex!(zeros($k),dt,$(find(idTracked)[1]))...))
$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,self.f,Tuple([$(n...)]), out, cache)
out
end
elseif nTracked<$k
V = map(x->:(value($x)),$n[idTracked])
V = :(SVector($(V...)))
ex = quote
$ex
T=promote_type($(v2e...))
result = DiffResults.GradientResult(StaticArrays.SVector(zeros(T,$kk)...))
nArray = Array{Any,1}(vcat($(n...)))
$((nArray,result,V,f)->
ForwardDiff.gradient!(result, x->f(setindex!(nArray,x,idTracked)...), V))(nArray,result,$V,self.f)
tp = tape($(n[idTracked]...))
out = track(DiffResults.value(result), $(v2e[end]), tp)
cache = RefValue(SVector(setindex!(zeros($k),DiffResults.gradient(result),$(find(idTracked)))...))
$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,self.f,Tuple([$(n...)]), out, cache)
out
end
else
V = map(x->:(value($x)),$names[1:$k])
V = :(SVector($(V...)))
ex = quote
$ex
T=promote_type($(v2e...))#,$(names[1].parameters[end]))
f = self.f
result = DiffResults.GradientResult(SVector(zeros(T,$kk)...))
$((result,V,f)->ForwardDiff.gradient!(result, x->f(x...), V))(result,$V,f)
tp = tape($(n...))
out = track(DiffResults.value(result), $(v2e[end]), tp)
cache = RefValue(DiffResults.gradient(result))
$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,f,Tuple([$(n...)]), out, cache)
out
end
end
end
:($ex)
end
end
eval(Q)
end
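# at this point the loop has generated and eval'ed call overloads of
# ForwardOptimize for 3 through 10 scalar arguments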
# a few ReverseDiff functions need to be adapted
@noinline function mult_scalar_forward_exec!{F,O,K}(f::F, output::O, input::NTuple{K,Any}, cache)
@. pull_value!(input)
idTracked = vcat((@. istracked(input))...)
if sum(idTracked)>1
v = map(value,input[idTracked])
result2 = DiffResults.GradientResult(SVector(
zeros(valtype(O),sum(idTracked))...))
nArray = Array{Any,1}(vcat(input...))
result2 = ForwardDiff.gradient!(result2, x->f(setindex!(nArray,x,idTracked)...), SVector(v...))
ReverseDiff.value!(output, DiffResults.value(result2))
cache[] = setindex!(zeros(K),DiffResults.gradient(result2),find(idTracked))
else
v =value(input[idTracked][1])
result1 = DiffResults.DiffResult(zero(valtype(O)), zero(valtype(O)))
nArray = Array{Any,1}(vcat(input...))
result1 = ForwardDiff.derivative!(result1, x->f(setindex!(nArray,x,idTracked)...), v)
ReverseDiff.value!(output, DiffResults.value(result1))
partial = DiffResults.derivative(result1)
cache[] = SVector{K,typeof(partial)}(setindex!(zeros(K),partial,find(idTracked)[1]))
end
nothing
end
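# hook the n-ary executor into ReverseDiff's forward pass: unary and binary
# instructions keep the stock paths, anything wider dispatches to
# mult_scalar_forward_exec! above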
@noinline function scalar_forward_exec!{F,I,O,C}(instruction::ScalarInstruction{F,I,O,C})
f = instruction.func
input = instruction.input
output = instruction.output
cache = instruction.cache
if istracked(input)
unary_scalar_forward_exec!(f, output, input, cache)
elseif length(input)==2
binary_scalar_forward_exec!(f, output, input, cache)
else
mult_scalar_forward_exec!(f, output, input, cache)
end
return nothing
end
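# reverse-pass helper: scatter the output's adjoint into each tracked
# input's derivative, using the partials cached during the forward pass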
@generated function sub_scalar_reverse_exec{T,U,V}(input::T,output::U,partials::V)
ex = Expr(:meta,:inline)
Un = Union{ReverseDiff.TrackedReal,ReverseDiff.TrackedArray}
if T<:Un
ex = :($ex;increment_deriv!(input, deriv(output) * partials))
elseif T<:Tuple
ex = :($ex;output_deriv = deriv(output))
for k in 1:length(T.parameters)
if T.parameters[k]<:Un
ex = ex.head==:(=) ?
:($ex;increment_deriv!(input[$k],output_deriv*partials[$k])) :
:($(ex.args...);increment_deriv!(input[$k],output_deriv*partials[$k]))
end
end
else
@show T
error("Unknown input")
end
ex = :($ex;unseed!(output))
:($ex)
end
@noinline function scalar_reverse_exec!{F,I,O,C}(instruction::ScalarInstruction{F,I,O,C})
f = instruction.func
input = instruction.input
output = instruction.output
partials = instruction.cache[]
sub_scalar_reverse_exec(input,output,partials)
unseed!(output)
return nothing
end
end
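A hypothetical usage sketch (untested; `h` and `∇h` are illustrative names), mirroring the compile_gradient example from the top of the thread but with five scalar arguments:
julia> import ReverseDiff
# define the enrichReverseDiff module above first, so its n-ary methods are loaded
julia> ReverseDiff.@forward h(a, b, c, d, e) = a + 2b + 3c + 4d + 5e
julia> ∇h = ReverseDiff.compile_gradient(x -> h(x[1], x[2], x[3], x[4], x[5]), zeros(5))
julia> ∇h(zeros(5), ones(5))  # expected: [1.0, 2.0, 3.0, 4.0, 5.0]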