Using `@grad` or `ChainRulesCore.rrule` with a compiled tape?
Is it possible to define a pullback via the `ReverseDiff.@grad` macro or a `ChainRulesCore.rrule` and use it with a compiled tape? When I try either approach with a simple function, the pullback seems to remember the input value that was used when the tape was compiled. Is this expected?
Here's an example with `@grad`:
module ReverseDiffGradMacro
using ReverseDiff
using ForwardDiff

# Primal function; its derivative is cos(x) - sin(x).
f1(x) = sin(x) + cos(x)

# Route tracked scalars through ReverseDiff's custom-gradient machinery so the
# `@grad` rule below is used instead of differentiating `sin`/`cos` directly.
f1(x::ReverseDiff.TrackedReal) = ReverseDiff.track(f1, x)

# Custom reverse rule for `f1`.
#
# The pullback reads the current value of `x` when it runs, instead of closing
# over a value snapshot taken when the rule was first evaluated. With a compiled
# tape the gradient was observed to stay pinned to the input used at compile
# time, which is what a snapshot captured during recording would produce if the
# recorded pullback is reused on replay.
# NOTE(review): this workaround assumes the tracked input `x` has its value
# refreshed during the tape's forward replay; confirm against the ReverseDiff
# version in use. Without a tape, `value(x)` is unchanged between the forward
# and reverse pass, so the result is identical to reading it eagerly.
ReverseDiff.@grad function f1(x::Real)
    function f1_pullback(Δ)
        println("hello from @grad f1_pullback")
        xv = ReverseDiff.value(x)  # read lazily, not at rule-evaluation time
        return (Δ * (cos(xv) - sin(xv)),)
    end
    return f1(ReverseDiff.value(x)), f1_pullback
end

"""
    doit()

Print the derivative of `f1` at `x = pi/6` computed by hand, with ForwardDiff,
and with ReverseDiff (without a tape and with a tape compiled at `x`). Then
repeat at `x + 0.1`, reusing the same compiled tape, so the results can be
compared. Returns `nothing`.
"""
function doit()
    f1vec = X -> f1(X[1])
    x = pi/6
    println("calculating gradient of f1vec, x = $x")
    dydx_hand = cos(x) - sin(x)
    dydx_rd1_notape = ReverseDiff.gradient(f1vec, [x])
    dydx_fd1 = ForwardDiff.gradient(f1vec, [x])
    @show dydx_hand dydx_fd1 dydx_rd1_notape

    println("calculating gradient of f1vec with compiled tape, x = $x")
    tape = ReverseDiff.GradientTape(f1vec, [x])
    compiled_tape = ReverseDiff.compile(tape)
    dydx_rd1_compiled = ReverseDiff.gradient!(compiled_tape, [x])
    @show dydx_rd1_compiled

    # Reuse the tape compiled at `x` for a different input; the result should
    # match the hand-derived and ForwardDiff derivatives at `x2`.
    x2 = x + 0.1
    println("calculating gradient of f1vec with compiled tape, x = $x2")
    dydx_hand2 = cos(x2) - sin(x2)
    dydx_fd2 = ForwardDiff.gradient(f1vec, [x2])
    dydx_rd2_compiled = ReverseDiff.gradient!(compiled_tape, [x2])
    @show dydx_hand2 dydx_fd2 dydx_rd2_compiled
    return nothing
end
end # module
And here's an example with an `rrule` from ChainRulesCore:
module ReverseDiffChainRulesCore
using ChainRulesCore: NoTangent
import ChainRulesCore: rrule
using ReverseDiff
using ForwardDiff

# Primal function; its derivative is cos(x) - sin(x).
f1(x) = sin(x) + cos(x)

# ChainRulesCore reverse rule for `f1`: returns the primal value together with
# a pullback mapping the output cotangent `ȳ` to (tangent for `f1` itself,
# tangent for `x`).
function rrule(::typeof(f1), x)
    y = f1(x)
    pullback = function (ȳ)
        println("hello from rrule f1_pullback")
        return NoTangent(), ȳ * (cos(x) - sin(x))
    end
    return y, pullback
end

# Make ReverseDiff use the rrule above whenever `f1` sees a tracked scalar.
ReverseDiff.@grad_from_chainrules f1(x::ReverseDiff.TrackedReal)

"""
    doit()

Print the derivative of `f1` at `x = π/6` computed by hand, with ForwardDiff,
and with ReverseDiff (without a tape and with a tape compiled at `x`). Then
repeat at `x + 0.1`, reusing the same compiled tape. Returns `nothing`.
"""
function doit()
    f1vec = X -> f1(X[1])

    # --- Reference derivatives at the recording point ---
    x = π / 6
    println("calculating gradient of f1vec, x = $x")
    dydx_hand = cos(x) - sin(x)
    dydx_rd1_notape = ReverseDiff.gradient(f1vec, [x])
    dydx_fd1 = ForwardDiff.gradient(f1vec, [x])
    @show dydx_hand dydx_fd1 dydx_rd1_notape

    # --- Record and compile a tape at `x`, then evaluate it there ---
    println("calculating gradient of f1vec with compiled tape, x = $x")
    ctape = ReverseDiff.compile(ReverseDiff.GradientTape(f1vec, [x]))
    dydx_rd1_compiled = ReverseDiff.gradient!(ctape, [x])
    @show dydx_rd1_compiled

    # --- Replay the same compiled tape at a shifted input ---
    x2 = x + 0.1
    println("calculating gradient of f1vec with compiled tape, x = $x2")
    dydx_hand2 = cos(x2) - sin(x2)
    dydx_fd2 = ForwardDiff.gradient(f1vec, [x2])
    dydx_rd2_compiled = ReverseDiff.gradient!(ctape, [x2])
    @show dydx_hand2 dydx_fd2 dydx_rd2_compiled

    return nothing
end
end # module
In both examples, `dydx_rd2_compiled` is the derivative computed using a compiled tape with input value `x2`. It should equal `dydx_hand2` and `dydx_fd2`, which are computed from the analytic expression for the derivative and with ForwardDiff, respectively. Instead, `dydx_rd2_compiled` matches all the earlier derivatives computed at `x`, the input value that was used when the tape was compiled.
Any ideas?
Thanks!