ForwardDiff.jl

Is this code a supported use? (single-pass value and derivative)

Open gerlero opened this issue 1 year ago • 6 comments

Consider these two methods, each of which computes the value and first derivative of a scalar function in a single pass:

import ForwardDiff
import DiffResults

@inline function value_and_derivative(f, Y::Type, x::Real)
    diffresult = ForwardDiff.derivative!(DiffResults.DiffResult(zero(Y), zero(Y)), f, x)
    return DiffResults.value(diffresult), DiffResults.derivative(diffresult)
end

@inline function value_and_derivative(f, x::Real)
    T = typeof(ForwardDiff.Tag(f, typeof(x)))
    ydual = f(ForwardDiff.Dual{T}(x, one(x)))
    return ForwardDiff.value(T, ydual), ForwardDiff.extract_derivative(T, ydual)
end

  • The first method seems to be the officially recommended way to do this. However, (1) it introduces the DiffResults dependency just for this, and (2) it needlessly requires the caller to specify f's return type ahead of the call. Neither is a dealbreaker, but IMO they add friction to something that looks like it shouldn't have any (intuition says that the value comes for free when computing a derivative with ForwardDiff).

  • The second method uses the Tag and Dual types as well as the extract_derivative function, none of which are listed under "Differentiation API" in the docs, so I'm not sure whether they're considered part of the stable public API.
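To illustrate the intuition that the value "comes for free", here is a minimal sketch of forward-mode differentiation with a hand-rolled dual number. MiniDual and mini_value_and_derivative are hypothetical names made up for this example; this is not ForwardDiff's Dual type, and it only supports the three operations defined below.

```julia
# Hypothetical toy dual number, for illustration only (not ForwardDiff.Dual).
struct MiniDual{T<:Real}
    val::T  # primal value
    der::T  # derivative carried alongside the value
end

# Each rule propagates the value and the derivative together.
Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.val + b.val, a.der + b.der)
Base.:*(a::MiniDual, b::MiniDual) = MiniDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::MiniDual) = MiniDual(sin(a.val), cos(a.val) * a.der)

# Seed the derivative slot with one(x) and evaluate f exactly once.
function mini_value_and_derivative(f, x::Real)
    y = f(MiniDual(x, one(x)))
    return y.val, y.der
end

g(x) = x * x + sin(x)
v, d = mini_value_and_derivative(g, 2.0)
```

Both v and d come out of the same evaluation of g, which is why a single-pass value_and_derivative feels like it should need no extra machinery.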

Both methods run equally fast (and significantly faster than a naive two-pass implementation), so my question is: does the second method constitute a supported use of ForwardDiff's public API?
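For reference, the naive two-pass baseline could look like the sketch below. The names two_pass_value_and_derivative, single_pass_value_and_derivative, calls and counted are made up for this example. The single-pass variant repeats the second method above and therefore still relies on Tag, Dual and extract_derivative, whose public status is exactly the open question here. The counter only shows how often f gets evaluated; it is not a benchmark.

```julia
import ForwardDiff

# Naive baseline: f runs once for the value and once more inside derivative.
two_pass_value_and_derivative(f, x::Real) = (f(x), ForwardDiff.derivative(f, x))

# Single pass, as in the second method (uses possibly internal names).
function single_pass_value_and_derivative(f, x::Real)
    T = typeof(ForwardDiff.Tag(f, typeof(x)))
    ydual = f(ForwardDiff.Dual{T}(x, one(x)))
    return ForwardDiff.value(T, ydual), ForwardDiff.extract_derivative(T, ydual)
end

# Count how many times the function is evaluated by each variant.
calls = Ref(0)
counted(x) = (calls[] += 1; exp(-x^2) * sin(3x))

calls[] = 0
two = two_pass_value_and_derivative(counted, 0.5)
two_calls = calls[]

calls[] = 0
single = single_pass_value_and_derivative(counted, 0.5)
single_calls = calls[]
```

Both variants should return the same pair up to floating-point rounding; the difference is that the baseline evaluates the function twice.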

If such a use isn't supported (and maybe even if it is, since both implementations appear too convoluted for a pretty common use case IMO), I'd like to suggest adding a value_and_derivative function based on the second method to either this package or one of the other packages in JuliaDiff (I'm willing to write a PR).

Related: #401, #391

EDITS: y-> ydual, ydual.value -> ForwardDiff.value(T, ydual), add another related issue

gerlero, Nov 24 '22 17:11