Prevent multiple calls to `forward` when using `ForwardDiff.jacobian`

Open sfalmo opened this issue 1 year ago • 0 comments

When the input is longer than ForwardDiff's chunk size, `ForwardDiff.jacobian` computes the Jacobian chunk by chunk, and the forward mapping $x \rightarrow y(x)$ is called again for every chunk. Here is a minimal example:

```julia
using ImplicitDifferentiation, ForwardDiff

x = rand(100)  # length(x) must be larger than the chunk size to see the effect

function sqrt_fixedpoint(x::AbstractArray)
    y = ones(eltype(x), size(x))
    for _ in 1:10
        y .= 0.5 .* (y .+ x ./ y)
    end
    return y
end

function forward(x)
    println("forward called")
    return sqrt_fixedpoint(x)
end

conditions(x, y) = 0.5 .* (y .+ x ./ y) .- y

implicit = ImplicitFunction(forward, conditions)

ForwardDiff.jacobian(implicit, x)
```

Output:

```
forward called
forward called
forward called
forward called
forward called
forward called
forward called
forward called
forward called

100×100 Matrix{Float64}:
...
```

This behavior is expected: forward-mode differentiation must return the value of the function (for which we need to call `forward`) along with the partials. But when the derivative is evaluated in chunks, every chunk after the first recomputes the same value, which is redundant and can be a major performance issue if `forward` is expensive.
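To estimate how many times `forward` will run for a given input length, here is a minimal sketch that mirrors, as far as I understand it, ForwardDiff's default chunk-size heuristic. The helper name `chunk_plan` is hypothetical, and the threshold of 12 is an assumption about ForwardDiff's default setting, which can be changed.

```julia
# Hypothetical helper: estimates the chunk size and the number of passes
# (and hence calls to the differentiated function) for an input of the
# given length. Assumes ForwardDiff's default chunk threshold of 12.
function chunk_plan(input_length::Integer; threshold::Integer = 12)
    if input_length <= threshold
        # Everything fits in a single chunk, so one pass suffices.
        return (chunk_size = input_length, passes = 1)
    end
    nchunks = cld(input_length, threshold)     # chunks needed at the threshold size
    chunk_size = cld(input_length, nchunks)    # spread the inputs evenly over them
    passes = cld(input_length, chunk_size)     # one function call per pass
    return (chunk_size = chunk_size, passes = passes)
end

plan = chunk_plan(100)
println("chunk size: ", plan.chunk_size, ", passes: ", plan.passes)
```

Under these assumptions, an input of length 100 is split into 9 passes of up to 12 partials each, which is consistent with the nine `forward called` lines in the output above.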

As a workaround, you could precalculate the solution and short-circuit the forward pass:

```julia
function forward(x; y_solution=nothing)
    if !isnothing(y_solution)
        println("forward called with y_solution, no work required")
        return y_solution
    end
    println("forward called, doing expensive computation")
    return sqrt_fixedpoint(x)
end

conditions(x, y; y_solution=nothing) = 0.5 .* (y .+ x ./ y) .- y

implicit = ImplicitFunction(forward, conditions)

y_solution = implicit(x)
ForwardDiff.jacobian(_x -> implicit(_x; y_solution), x)
```

Output:

```
forward called, doing expensive computation
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required

100×100 Matrix{Float64}:
...
```
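A complementary knob, separate from the workaround, is to request a larger chunk size so that fewer passes (and therefore fewer calls) are needed in the first place. The sketch below uses only ForwardDiff with a hypothetical stand-in function `expensive` that counts its own calls; I have not checked how this interacts with `ImplicitFunction`, so treat it as an illustration of the ForwardDiff side only.

```julia
using ForwardDiff

# Hypothetical stand-in for an expensive forward solve; counts its own calls.
const ncalls = Ref(0)
function expensive(x)
    ncalls[] += 1
    return sqrt.(x)
end

x = rand(100) .+ 1  # shifted away from zero so that sqrt is differentiable

# Request 50 partials per pass instead of the default chunk size.
cfg = ForwardDiff.JacobianConfig(expensive, x, ForwardDiff.Chunk{50}())
J = ForwardDiff.jacobian(expensive, x, cfg)

println("number of calls: ", ncalls[])
```

Larger chunks trade fewer function calls for wider dual numbers, which can increase compile time and memory pressure, so the best chunk size depends on how expensive the function is relative to that overhead.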

It seems to me that there is no sane way to solve this directly in the implementation of `ImplicitFunction` for dual numbers. You would need some caching mechanism for the function value, which goes against the general idea of forward-mode autodiff.

Maybe it would be useful to add the workaround to the documentation; the sections "Advanced use cases" or "Tricks" seem like good places.

sfalmo · Feb 15 '24 10:02