ReverseDiff.jl icon indicating copy to clipboard operation
ReverseDiff.jl copied to clipboard

`@grad` or `ChainRulesCore.rrule` with compiled tape?

Open dingraha opened this issue 4 years ago • 0 comments

Is it possible to define a pullback via the `ReverseDiff.@grad` macro or a `ChainRulesCore.rrule` and use it with a compiled tape? When I try either approach with a simple function, the pullback seems to remember the input value that was used when the tape was compiled. Is this expected?

Here's an example with @grad:

# Minimal reproducer: defines a hand-written `ReverseDiff.@grad` rule for `f1` and
# differentiates it (a) without a tape and (b) twice through one compiled tape, at two
# different input values, printing each result next to hand-computed/ForwardDiff values.
module ReverseDiffGradMacro

using ReverseDiff
using ForwardDiff

# Scalar test function. `doit` below uses cos(x) - sin(x) as its hand-written derivative.
f1(x) = sin(x) + cos(x)

# Tracked scalars are forwarded to `ReverseDiff.track`.
# NOTE(review): presumably this is what routes differentiation through the `@grad` rule
# defined next — confirm against the ReverseDiff docs.
f1(x::ReverseDiff.TrackedReal) = ReverseDiff.track(f1, x)
ReverseDiff.@grad function f1(x::Real)
    # The plain value is extracted once here, outside the pullback ...
    xv = ReverseDiff.value(x)

    # ... so the closure below uses whatever `xv` was when this function body ran, not a
    # value looked up at the time the pullback itself is called.
    function f1_pullback(Δ)
        # Printed on every pullback call, making it visible when the pullback runs.
        println("hello from @grad f1_pullback")
        # Returned as a 1-tuple holding the derivative contribution for `x`.
        return (Δ*(cos(xv) - sin(xv)),)
    end

    # Primal value together with the pullback closure.
    return f1(xv), f1_pullback
end

# Run the comparison and print the results; returns `nothing`.
function doit()
    # Vector-input wrapper, since the gradient calls below are given 1-element vectors.
    f1vec = X->f1(X[1])
    x = pi/6
    
    # Baseline at `x`: hand-written derivative, ForwardDiff, and ReverseDiff without a tape.
    println("calculating gradient of f1vec, x = $x")
    dydx_hand = cos(x) - sin(x)
    dydx_rd1_notape = ReverseDiff.gradient(f1vec, [x])
    dydx_fd1 = ForwardDiff.gradient(f1vec, [x])
    @show dydx_hand dydx_fd1 dydx_rd1_notape
    
    # Record the tape at `x`, compile it, and evaluate the gradient at the same `x`.
    println("calculating gradient of f1vec with compiled tape, x = $x")
    tape = ReverseDiff.GradientTape(f1vec, [x])
    compiled_tape = ReverseDiff.compile(tape)
    dydx_rd1_compiled = ReverseDiff.gradient!(compiled_tape, [x])
    @show dydx_rd1_compiled

    # Reuse the SAME compiled tape at a different input `x2`.
    x2 = x + 0.1
    println("calculating gradient of f1vec with compiled tape, x = $x2")
    dydx_hand2 = cos(x2) - sin(x2)
    dydx_fd2 = ForwardDiff.gradient(f1vec, [x2])
    # NOTE(review): this is the value under investigation — it is expected to agree with
    # `dydx_hand2` and `dydx_fd2`, but is reported to equal the derivative at `x` instead.
    dydx_rd2_compiled = ReverseDiff.gradient!(compiled_tape, [x2])
    @show dydx_hand2 dydx_fd2 dydx_rd2_compiled

    return nothing
end

end # module

And here's an example with an rrule from ChainRulesCore:

# Minimal reproducer: same experiment as a hand-written gradient rule, but the rule is
# written as a `ChainRulesCore.rrule` and handed to ReverseDiff via
# `ReverseDiff.@grad_from_chainrules`. Gradients are computed without a tape and twice
# through one compiled tape, at two different input values.
module ReverseDiffChainRulesCore

using ChainRulesCore: NoTangent
import ChainRulesCore: rrule

using ReverseDiff
using ForwardDiff

# Scalar test function. `doit` below uses cos(x) - sin(x) as its hand-written derivative.
f1(x) = sin(x) + cos(x)

# Reverse rule for `f1`: returns the primal value and a pullback closure.
function rrule(::typeof(f1), x)

    # The closure captures the `x` this particular `rrule` call received.
    function f1_pullback(ybar)
        # Printed on every pullback call, making it visible when the pullback runs.
        println("hello from rrule f1_pullback")
        # `NoTangent()` is returned in the first slot, the derivative w.r.t. `x` in the second.
        return NoTangent(), ybar*(cos(x) - sin(x))
    end

    return f1(x), f1_pullback
end

# NOTE(review): presumably this makes ReverseDiff use the `rrule` above when `f1` is called
# with a tracked scalar — confirm against the ReverseDiff docs.
ReverseDiff.@grad_from_chainrules f1(x::ReverseDiff.TrackedReal)

# Run the comparison and print the results; returns `nothing`.
function doit()
    # Vector-input wrapper, since the gradient calls below are given 1-element vectors.
    f1vec = X->f1(X[1])
    x = pi/6
    
    # Baseline at `x`: hand-written derivative, ForwardDiff, and ReverseDiff without a tape.
    println("calculating gradient of f1vec, x = $x")
    dydx_hand = cos(x) - sin(x)
    dydx_rd1_notape = ReverseDiff.gradient(f1vec, [x])
    dydx_fd1 = ForwardDiff.gradient(f1vec, [x])
    @show dydx_hand dydx_fd1 dydx_rd1_notape
    
    # Record the tape at `x`, compile it, and evaluate the gradient at the same `x`.
    println("calculating gradient of f1vec with compiled tape, x = $x")
    tape = ReverseDiff.GradientTape(f1vec, [x])
    compiled_tape = ReverseDiff.compile(tape)
    dydx_rd1_compiled = ReverseDiff.gradient!(compiled_tape, [x])
    @show dydx_rd1_compiled

    # Reuse the SAME compiled tape at a different input `x2`.
    x2 = x + 0.1
    println("calculating gradient of f1vec with compiled tape, x = $x2")
    dydx_hand2 = cos(x2) - sin(x2)
    dydx_fd2 = ForwardDiff.gradient(f1vec, [x2])
    # NOTE(review): this is the value under investigation — it is expected to agree with
    # `dydx_hand2` and `dydx_fd2`, but is reported to equal the derivative at `x` instead.
    dydx_rd2_compiled = ReverseDiff.gradient!(compiled_tape, [x2])
    @show dydx_hand2 dydx_fd2 dydx_rd2_compiled

    return nothing
end

end # module

In both examples, `dydx_rd2_compiled` is a derivative calculated using a compiled tape with input value `x2`. It should be the same as `dydx_hand2` and `dydx_fd2`, which are found using the analytic expression for the derivative and ForwardDiff, respectively. But `dydx_rd2_compiled` actually matches all the previous derivatives calculated with input value `x`, which was used when the tape was compiled.

Any ideas?

Thanks!

dingraha avatar Nov 21 '21 21:11 dingraha