
The argument-count limitation of the @forward macro.

bicycle1885 opened this issue on Feb 17, 2017 · 4 comments

The docstring of ReverseDiff.@forward says:

Currently, only length(args) <= 2 is supported.

However, for a simple function, the macro seems to work even for more than two arguments (four, here):

julia> import ReverseDiff

julia> ReverseDiff.@forward f(x, y, z, w) = x + 2y + 3z + 4w
ReverseDiff.ForwardOptimize{##hidden_f}(#hidden_f)

julia> f(1.0, 2.0, 3.0, 4.0)
30.0

julia> ∇f = ReverseDiff.compile_gradient(x -> f(x[1], x[2], x[3], x[4]), zeros(4))
(::#301) (generic function with 1 method)

julia> ∇f(zeros(4), ones(4))
4-element Array{Float64,1}:
 1.0
 2.0
 3.0
 4.0

Can I always expect this to work, or is there a pitfall in this case?

bicycle1885 avatar Feb 17 '17 06:02 bicycle1885

I believe this is actually hitting the n-ary fallback, which does indeed still work, but doesn't perform the whole differentiation in forward mode:

julia> import ReverseDiff

julia> ReverseDiff.@forward f(x, y, z, w) = x + 2y + 3z + 4w
ReverseDiff.ForwardOptimize{##hidden_f}(#hidden_f)

julia> cfg = ReverseDiff.GradientConfig(rand(4))
ReverseDiff.GradientConfig

# record operations to cfg.tape
julia> f(cfg.input...)
TrackedReal<5j3>(1.30302042715e-313, 0.0, FT7, ---)

# if `f` were getting differentiated completely in forward mode,
# there would be only a single instruction in the tape
julia> cfg.tape
6-element RawTape:
1 => ScalarInstruction(*):
  input:  (2,
           TrackedReal<8dC>(2.1641615666e-314, 0.0, FT7, 2, Hot))
  output: TrackedReal<9CN>(4.328323133e-314, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(2.0,2.0))
2 => ScalarInstruction(*):
  input:  (3,
           TrackedReal<GXL>(2.167760388e-314, 0.0, FT7, 3, Hot))
  output: TrackedReal<2e3>(6.5032811646e-314, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(3.0,3.0))
3 => ScalarInstruction(*):
  input:  (4,
           TrackedReal<Kzu>(3.29036e-318, 0.0, FT7, 4, Hot))
  output: TrackedReal<2UN>(1.3161435e-317, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(4.0,4.0))
4 => ScalarInstruction(+):
  input:  (TrackedReal<4hC>(2.1972838303e-314, 0.0, FT7, 1, Hot),
           TrackedReal<9CN>(4.328323133e-314, 0.0, FT7, ---))
  output: TrackedReal<KtE>(6.5256069634e-314, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))
5 => ScalarInstruction(+):
  input:  (TrackedReal<KtE>(6.5256069634e-314, 0.0, FT7, ---),
           TrackedReal<2e3>(6.5032811646e-314, 0.0, FT7, ---))
  output: TrackedReal<40f>(1.3028888128e-313, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))
6 => ScalarInstruction(+):
  input:  (TrackedReal<40f>(1.3028888128e-313, 0.0, FT7, ---),
           TrackedReal<2UN>(1.3161435e-317, 0.0, FT7, ---))
  output: TrackedReal<5j3>(1.30302042715e-313, 0.0, FT7, ---)
  cache:  Base.RefValue{ForwardDiff.Partials{2,Float64}}(Partials(1.0,1.0))

This behavior is only a stop-gap until @forward supports an arbitrary number of arguments, which should be the case eventually.
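
For comparison, a minimal sketch of the fully-forward binary case (assuming the same ReverseDiff version as in this thread; not actual REPL output): with only two arguments, recording the function should leave a single instruction on the tape.

import ReverseDiff

ReverseDiff.@forward g(x, y) = x + 2y

cfg = ReverseDiff.GradientConfig(rand(2))
g(cfg.input...)   # record to cfg.tape
length(cfg.tape)  # expected: 1, since g falls within the supported binary case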

jrevels avatar Feb 20 '17 03:02 jrevels

@jrevels Thank you for the clarification. So a function with more than two arguments is differentiated using a mixture of forward- and reverse-mode differentiation, right? I don't know why this limitation exists, but I'm happy to know it will eventually be removed.

bicycle1885 avatar Feb 23 '17 02:02 bicycle1885

So a function with more than two arguments is differentiated using a mixture of forward- and reverse-mode differentiation, right?

Yup.

I don't know why this limitation exists, but I'm happy to know it will eventually be removed.

It's simply more difficult to write an n-ary version, so I started with the unary and binary versions.
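
The essence of the n-ary case can be sketched with plain ForwardDiff (illustrative only, not ReverseDiff's actual internals): a single forward-mode evaluation yields the function's value together with all of its partials, which is exactly what a one-instruction tape entry would cache.

import ForwardDiff

# One evaluation gives every partial derivative of the scalar function,
# i.e. the local gradient an n-ary instruction would record.
f(x, y, z, w) = x + 2y + 3z + 4w

partials = ForwardDiff.gradient(v -> f(v...), [1.0, 2.0, 3.0, 4.0])
# partials == [1.0, 2.0, 3.0, 4.0]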

I'm actually going to re-open this issue to track progress on n-ary @forward support.

jrevels avatar Mar 09 '17 16:03 jrevels

If it helps, I wrote a module that makes the @forward machinery usable for functions with an arbitrary number of scalar arguments (up to 10 in the example here). Is there perhaps a good reason not to do this, or not to do it this way? Code optimisation is not my biggest strength, so I would be glad if anyone has suggestions on how to improve it!

@everywhere module enrichReverseDiff
using ReverseDiff, StaticArrays, ForwardDiff, DiffResults
import StaticArrays: SVector
import ReverseDiff: scalar_forward_exec!, ScalarInstruction, istracked, deriv,
    unseed!, increment_deriv!, pull_value!, value, RefValue, record!,
    unary_scalar_forward_exec!, binary_scalar_forward_exec!, tape, track,
    valtype, ForwardOptimize, TrackedReal, scalar_reverse_exec!, GradientTape,
    _GradientTape, track!, GradientConfig
# For 10 args
names = (:(aa,bb,cc,dd,ee,ff,gg,hh,ii,jj).args...)
types = (:(AA,BB,CC,DD,EE,FF,GG,HH,II,JJ).args...)

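# For each arity k = 3:10, generate a call method on ForwardOptimize that
# dispatches on which arguments are TrackedReal, computes the local value and
# partials with ForwardDiff, and records a single n-ary ScalarInstruction.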
for k in 3:10 
	n = names[1:k]
	N = parse(string("(",join(string.(n),","),")"))
	t = types[1:k]
	T = parse(string("(",join(string.(t),","),")"))
	arg = map((x,y)->:($x::$y),n,t)
	Q = quote
		@generated function (self::ForwardOptimize{F})($(arg...)) where {F,$(t...)}
			# Identify and Count the number of args that are TrackedReal
			ex = Expr(:meta,:inline)
			nTracked = 0
			idTracked = Array{Bool,1}($k)
			for j in 1:$k
				if $(T)[j]<:TrackedReal
					nTracked+=1
					idTracked[j]=true
				else
					idTracked[j]=false
				end
			end
			
			n = $n
			if nTracked==0
				# fallback
				ex = quote
					$ex
					self.f($(n...))
				end
			else
			k = $k
			v2e = (map(x->:(typeof($x.value)),$n[idTracked])...,
			       map(x->:(typeof($x)),$n[.!idTracked])...,
			       (x->:(typeof($x.deriv)))($n[findfirst(idTracked)]))
			kk = sum(idTracked)
			if nTracked==1
				V = :(value($(n[idTracked][1])))
				ex = quote
					$ex
					T=promote_type($(v2e...))
					result = DiffResults.DiffResult(zero(T),zero(T))
					nArray = Array{Any,1}(vcat($(n...)))
					$((nArray,result,V,f)->
					  ForwardDiff.derivative!(result, x->f(setindex!(nArray,x,idTracked)...), V))(nArray,result,$V,self.f)
					tp = tape($(n[idTracked][1]))
					out = track(DiffResults.value(result), $(v2e[end]), tp)
					dt = DiffResults.derivative(result)
					cache = RefValue(SVector{$k,typeof(dt)}(setindex!(zeros($k),dt,$(find(idTracked)[1]))...))
					$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,self.f,Tuple([$(n...)]), out, cache)
					out
				end
			elseif nTracked<$k
				
				V = map(x->:(value($x)),$n[idTracked])
				V = :(SVector($(V...)))
				
				ex = quote
					$ex
					T=promote_type($(v2e...))
					result = DiffResults.GradientResult(StaticArrays.SVector(zeros(T,$kk)...))
					nArray = Array{Any,1}(vcat($(n...)))
					$((nArray,result,V,f)->
					  ForwardDiff.gradient!(result, x->f(setindex!(nArray,x,idTracked)...), V))(nArray,result,$V,self.f)
					tp = tape($(n[idTracked]...))
					out = track(DiffResults.value(result), $(v2e[end]), tp)
					cache = RefValue(SVector(setindex!(zeros($k),DiffResults.gradient(result),$(find(idTracked)))...))
					$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,self.f,Tuple([$(n...)]), out, cache)
					out
				end
			else
				V = map(x->:(value($x)),$names[1:$k])
				V = :(SVector($(V...)))
				ex = quote
					$ex
					T=promote_type($(v2e...))#,$(names[1].parameters[end]))
					f = self.f
					result = DiffResults.GradientResult(SVector(zeros(T,$kk)...))
					$((result,V,f)->ForwardDiff.gradient!(result, x->f(x...), V))(result,$V,f)
					tp = tape($(n...))
					out = track(DiffResults.value(result), $(v2e[end]), tp)
					cache = RefValue(DiffResults.gradient(result))
					$((tp,f,N,out,cache)->record!(tp, ScalarInstruction, f, N,out,cache))(tp,f,Tuple([$(n...)]), out, cache)
					out
				end
			end
		end
			:($ex)
		end	
	end
	eval(Q)
end
# a few ReverseDiff functions need to be adapted
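# Forward re-execution of an n-ary instruction: recompute the output value and
# the local partials with ForwardDiff, and store the partials in the cache.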
@noinline function mult_scalar_forward_exec!{F,O,K}(f::F, output::O, input::NTuple{K,Any}, cache)
	@. pull_value!(input)
	idTracked = vcat((@. istracked(input))...)
	if sum(idTracked)>1
		v = map(value,input[idTracked])
		result2 = DiffResults.GradientResult(SVector(
			zeros(valtype(O),sum(idTracked))...))
		nArray = Array{Any,1}(vcat(input...))
		result2 = ForwardDiff.gradient!(result2, x->f(setindex!(nArray,x,idTracked)...), SVector(v...))
		ReverseDiff.value!(output, DiffResults.value(result2))
		cache[] = setindex!(zeros(K),DiffResults.gradient(result2),find(idTracked))
	else
		v = value(input[idTracked][1])
		result1 = DiffResults.DiffResult(zero(valtype(O)), zero(valtype(O)))
		nArray = Array{Any,1}(vcat(input...))
		result1 = ForwardDiff.derivative!(result1, x->f(setindex!(nArray,x,idTracked)...), v)
		ReverseDiff.value!(output, DiffResults.value(result1))
		partial = DiffResults.derivative(result1)
		cache[] = SVector{K,typeof(partial)}(setindex!(zeros(K),partial,find(idTracked)[1]))
	end
	nothing
end

@noinline function scalar_forward_exec!{F,I,O,C}(instruction::ScalarInstruction{F,I,O,C})
    f = instruction.func
    input = instruction.input
    output = instruction.output
    cache = instruction.cache

    if istracked(input)
        unary_scalar_forward_exec!(f, output, input, cache)
    elseif length(input) == 2
        binary_scalar_forward_exec!(f, output, input, cache)
    else
        mult_scalar_forward_exec!(f, output, input, cache)
    end
    return nothing
end
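# Reverse pass helper: propagate the output adjoint to every tracked input
# using the cached partials, then unseed the output.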
@generated function sub_scalar_reverse_exec{T,U,V}(input::T,output::U,partials::V)
	ex = Expr(:meta,:inline)
	Un = Union{ReverseDiff.TrackedReal,ReverseDiff.TrackedArray}
	if T<:Un
		ex = :($ex;increment_deriv!(input, deriv(output) * partials))
	elseif T<:Tuple
		ex = :($ex;output_deriv = deriv(output))
		for k in 1:length(T.parameters)
			if T.parameters[k]<:Un
				ex = ex.head==:(=) ?
				:($ex;increment_deriv!(input[$k],output_deriv*partials[$k])) : 
				:($(ex.args...);increment_deriv!(input[$k],output_deriv*partials[$k]))
			end
		end
	else
		@show T
		error("Unknown input")
	end
	ex = :($ex;unseed!(output))
	:($ex)
end

@noinline function scalar_reverse_exec!{F,I,O,C}(instruction::ScalarInstruction{F,I,O,C})
    f = instruction.func
    input = instruction.input
    output = instruction.output
    partials = instruction.cache[]
    sub_scalar_reverse_exec(input,output,partials)
    unseed!(output)
    return nothing
end
end
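
A hypothetical usage sketch (untested; the function h and its body are made up for illustration, and the API is the Julia 0.6-era one used earlier in this thread):

using ReverseDiff

# With the module above loaded, @forward definitions with 3-10 scalar
# arguments should record a single n-ary instruction instead of hitting
# the mixed-mode fallback.
ReverseDiff.@forward h(a, b, c) = a * b + sin(c)

∇h = ReverseDiff.compile_gradient(x -> h(x[1], x[2], x[3]), zeros(3))
∇h(zeros(3), ones(3))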

vmoens avatar Feb 12 '18 06:02 vmoens