ImplicitDifferentiation.jl
Issue with second derivative
Hi, I have been trying to use ImplicitDifferentiation.jl for higher-order JVPs, but I haven't had any luck.
I want to use it with BifurcationKit.jl, where I need third-order JVPs.
Basic example:
using ImplicitDifferentiation
using Optim
using Random
using Zygote
Random.seed!(63)
# Forward solver: minimize ||y - x||^2 with Optim, so the implicit function is y(x) = x
function dumb_identity(x)
    f(y) = sum(abs2, y - x)
    y0 = zero(x)
    res = optimize(f, y0, LBFGS(); autodiff=:forward)
    y = Optim.minimizer(res)
    return y
end;
# Optimality conditions: the gradient of the objective must vanish at the optimum
zero_gradient(x, y) = 2(y - x);
implicit = ImplicitFunction(dumb_identity, zero_gradient);
x = rand(3, 2)
h = rand(3, 2)
D(x,h) = Zygote.jacobian(t->implicit(x .+ t .* h), 0)[1]
D(x,h) # works
D2(x,h1,h2) = Zygote.jacobian(t->D(x .+ t .* h2,h1), 0)[1]
D2(x,h,h) # does not work
I also have this code with ForwardDiff; I'm not sure whether the problem is the same:
using ForwardDiffChainRules, ForwardDiff
@ForwardDiff_frule (f::typeof(implicit))(x::AbstractMatrix{<:ForwardDiff.Dual})
D(x,h) = ForwardDiff.derivative(t->implicit(x .+ t .* h), 0)
D(x,h) # works
D2(x,h1,h2) = ForwardDiff.derivative(t->D(x .+ t .* h2,h1), 0)
D2(x,h,h) # does not work
What's the error?
For Zygote:
julia> D2(x,h,h) # does not work
ERROR: Compiling Tuple{typeof(Optim.perform_linesearch!), Optim.LBFGSState{Matrix{Float64}, Vector{Matrix{Float64}}, Vector{Matrix{Float64}}, Float64, Matrix{Float64}}, LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Optim.var"#19#21"}, Optim.ManifoldObjective{OnceDifferentiable{Float64, Matrix{Float64}, Matrix{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
For ForwardDiff:
julia> D2(x,h,h) # does not work
ERROR: MethodError: no method matching gmres(::LinearOperators.LinearOperator{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}, Int64, ImplicitDifferentiation.var"#mul_A!#11"{ImplicitDifferentiation.var"#pushforward_A#9"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ForwardDiffChainRules.ForwardDiffRuleConfig, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Nothing, Nothing, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}})
It seems the second differentiation bypasses the machinery of ImplicitDifferentiation.
Do you think this requires implementing the second derivative of the implicit function theorem, or should the algorithm recursively call the first derivative?
Do you think this requires implementing the second derivative of the implicit function theorem, or should the algorithm recursively call the first derivative?
I've been asking myself the same question for a project of mine. The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.
It seems the second differentiation bypasses the machinery of ImplicitDifferentiation.
That's definitely worth investigating. Maybe we should define a rule on the pushforward / pullback?
I added the link to the Discourse post.
The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.
An example is here https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/custom.jl
Changing this line https://github.com/gdalle/ImplicitDifferentiation.jl/blob/1581d2e3b1b1ddf083f0d370c9f7b323aa98f610/src/implicit_function.jl#L64 to
y = implicit(x; kwargs...)
and using a linear solver that's compatible with ForwardDiff makes the ForwardDiff case work:
linear_solver(A, x) = (Matrix(A) \ x, (solved = true,))
implicit = ImplicitFunction(dumb_identity, zero_gradient, linear_solver);
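For concreteness, a sketch of re-running the ForwardDiff second derivative with this dense solver (the @ForwardDiff_frule overload may need to be re-applied here, since the concrete type of implicit changes with the linear solver):
@ForwardDiff_frule (f::typeof(implicit))(x::AbstractMatrix{<:ForwardDiff.Dual})
D(x,h) = ForwardDiff.derivative(t->implicit(x .+ t .* h), 0)
D2(x,h1,h2) = ForwardDiff.derivative(t->D(x .+ t .* h2,h1), 0)
D2(x,h,h) # should now run instead of hitting the gmres MethodError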
Zygote still fails, but I suspect that's because Zygote.jacobian mutates, which trips up Zygote's second-order differentiation. I have run into this issue in the past and partially avoided it by defining a non-mutating version of jacobian. Perhaps using AbstractDifferentiation's implementation would be better here. Anyway, I'll leave the rest to you :)
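For reference, a non-mutating jacobian along those lines could look roughly like this (jacobian_nomut is a hypothetical helper, not part of Zygote or AbstractDifferentiation, and whether it is enough to unblock the full second-order example above is untested):
using Zygote
# Seed the pullback with each basis cotangent and concatenate the resulting rows,
# instead of writing columns into a preallocated matrix like Zygote.jacobian does.
function jacobian_nomut(f, x)
    y, back = Zygote.pullback(f, x)
    n = length(y)
    seeds = [reshape([i == j ? 1.0 : 0.0 for j in 1:n], size(y)) for i in 1:n]
    rows = [permutedims(vec(collect(back(seed)[1]))) for seed in seeds]
    return reduce(vcat, rows) # size length(y) x length(x), like Zygote.jacobian
end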
Changing this line
Seems easy enough to fix in the source code, should we?
Seems easy enough to fix in the source code, should we?
Let's
Random thought: the fact that we need a dual-compatible linear solver means we still autodiff through the iterations of the solver. Is it possible to avoid that altogether for second-order?
In theory we could. That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that. We would need to think of a good way to make this work for direct and iterative linear solvers.
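For intuition, here is why that pays off: if A(t) * y(t) = b(t) defines the solution y, then differentiating gives A * y' = b' - A' * y, so the derivative of the solve is itself just one more solve with the same A, and the same reasoning applies again at higher orders, without ever differentiating through the iterations of the solver. A tiny self-contained check on toy matrices (not the package API):
using LinearAlgebra
A(t) = [2.0 + t  1.0; 0.0  3.0]
b(t) = [1.0, t]
y(t) = A(t) \ b(t)
dA = [1.0 0.0; 0.0 0.0]                      # dA/dt
db = [0.0, 1.0]                              # db/dt
t0 = 0.3
dy = A(t0) \ (db - dA * y(t0))               # derivative via the implicit relation
dy_fd = (y(t0 + 1e-6) - y(t0 - 1e-6)) / 2e-6 # finite-difference reference
@assert isapprox(dy, dy_fd; atol=1e-5)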
Do you know if the Python package has higher order implicit derivatives? If not, this could be an interesting conference paper.
Do you know if the Python package has higher order implicit derivatives?
No, it only focuses on first-order differentiation. But it outsources the actual autodiff to JAX, which may be better at second-order stuff than Zygote, for example.
That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that.
My thought was actually to differentiate the implicit function theorem a second time, like this.
I believe what the derivations in this link are doing is going to be computationally identical to making our linear solver an implicit function and letting ImplicitDifferentiation automate the rest for us. They just defined the higher-order rule manually, but that gets ugly with multiple inputs and outputs, Hessian rules for functions, etc. We have a nice abstraction which can be nested; let's take advantage of that.
what the derivations in this link are doing is going to be computationally identical to making our linear solver an implicit function
I believe you're right. Although I think the SciML solvers from https://github.com/SciML/LinearSolve.jl are already differentiable with similar implicit machinery, it can't hurt to roll our own, simpler version.
I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough. Maybe we can revisit it or roll our own, shouldn't be too hard for simple solvers.
I don't think we need that many solvers so direct and GMRES should be enough.
I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough
It should definitely be more mature now, but my main beef is that it has a truckload of dependencies, which cannot be made conditional until Julia 1.9:
https://github.com/SciML/LinearSolve.jl/blob/main/Project.toml
Now (in main), you can pass DirectLinearSolver() as the linear_solver of the implicit function to do second-order differentiation. For the more general case of making the linear solver an implicit function, I think we should open another issue.
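For readers landing on this thread, a minimal usage sketch with the example from above (assuming the solver is still the third positional argument of the constructor, as in the earlier workaround):
implicit = ImplicitFunction(dumb_identity, zero_gradient, DirectLinearSolver());
D2(x,h,h) # the second-order ForwardDiff example above, now backed by a direct solve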
See #77
Closing in favor of #77