ImplicitDifferentiation.jl
Prevent multiple calls to `forward` when using `ForwardDiff.jacobian`
When computing sufficiently large Jacobians with ForwardDiff, the forward mapping $x \rightarrow y(x)$ is called over and over again: once for each chunk that `ForwardDiff.jacobian` processes.
Here is a minimal example:
```julia
using ImplicitDifferentiation, ForwardDiff

x = rand(100)  # length(x) must be larger than the chunk size to see the effect

function sqrt_fixedpoint(x::AbstractArray)
    y = ones(eltype(x), size(x))
    for _ in 1:10
        y .= 0.5 .* (y .+ x ./ y)
    end
    y
end

function forward(x)
    println("forward called")
    sqrt_fixedpoint(x)
end

conditions(x, y) = 0.5 .* (y .+ x ./ y) .- y

implicit = ImplicitFunction(forward, conditions)
ForwardDiff.jacobian(implicit, x)
```
Output:
```
forward called
forward called
forward called
forward called
forward called
forward called
forward called
forward called
forward called
100×100 Matrix{Float64}:
...
```
This behavior is expected, since forward-mode differentiation also returns the value of the function (for which we need to call `forward`) along with the partials. But for chunked evaluations of derivatives this is redundant, and it can become a major performance issue if `forward` is expensive.
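To see where the repeated calls come from, here is a minimal sketch with plain ForwardDiff (no `ImplicitFunction` involved); the function `g`, the call counter and the chunk size of 10 are made up for illustration. ForwardDiff evaluates the target function once per chunk, and inside an `ImplicitFunction` each of those evaluations triggers `forward` on the primal values:

```julia
using ForwardDiff

calls = Ref(0)
g(x) = (calls[] += 1; x .^ 2)  # stand-in for an expensive forward mapping

x = rand(100)
cfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{10}())  # force chunk size 10
ForwardDiff.jacobian(g, x, cfg)
calls[]  # == 10, i.e. ceil(100 / 10): one evaluation of g per chunk
```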
As a workaround, you could precalculate the solution and short-circuit the forward pass:
```julia
function forward(x; y_solution=nothing)
    if !isnothing(y_solution)
        println("forward called with y_solution, no work required")
        return y_solution
    end
    println("forward called, doing expensive computation")
    sqrt_fixedpoint(x)
end

conditions(x, y; y_solution=nothing) = 0.5 .* (y .+ x ./ y) .- y

implicit = ImplicitFunction(forward, conditions)
y_solution = implicit(x)
ForwardDiff.jacobian(_x -> implicit(_x; y_solution), x)
```
Output:
```
forward called, doing expensive computation
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
forward called with y_solution, no work required
100×100 Matrix{Float64}:
...
```
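For a quick sanity check, the short-circuited version should reproduce the Jacobian of the plain call:

```julia
J_direct = ForwardDiff.jacobian(implicit, x)                        # calls forward once per chunk
J_cached = ForwardDiff.jacobian(_x -> implicit(_x; y_solution), x)  # reuses the precomputed solution
J_direct ≈ J_cached  # should be true
```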
It seems to me as if there is no sane way of solving this directly in the implementation of `implicit` for dual numbers: you would need some caching mechanism for the function value, which goes against the general idea of forward-mode autodiff. Maybe it would be useful to add the workaround to the documentation; the "Advanced use cases" or "Tricks" sections seem like good places.
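As an aside, another possible mitigation on the caller side (I have not tested how well this interacts with `ImplicitFunction`) would be to force a single chunk covering all of `x`, so that the dual evaluation, and with it `forward`, runs only once; the downside is that very large chunk sizes increase compile time and memory use:

```julia
cfg = ForwardDiff.JacobianConfig(implicit, x, ForwardDiff.Chunk{length(x)}())
ForwardDiff.jacobian(implicit, x, cfg)  # a single chunk, so forward should be called only once
```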