ImplicitDifferentiation.jl

Issue with second derivative

Open rveltz opened this issue 2 years ago • 22 comments

Hi, I have been trying to use ImplicitDifferentiation.jl for higher-order JVPs (Jacobian-vector products), but without success.

I want to use it in BifurcationKit.jl, where I need third-order JVPs.

Basic example:

using ImplicitDifferentiation
using Optim
using Random
using Zygote


Random.seed!(63)

function dumb_identity(x)
    f(y) = sum(abs2, y-x)
    y0 = zero(x)
    res = optimize(f, y0, LBFGS(); autodiff=:forward)
    y = Optim.minimizer(res)
    return y
end;

zero_gradient(x, y) = 2(y - x);
implicit = ImplicitFunction(dumb_identity, zero_gradient);
x = rand(3, 2)
h = rand(3, 2)
D(x,h) = Zygote.jacobian(t->implicit(x .+ t .* h), 0)[1]
D(x,h) # works
D2(x,h1,h2) = Zygote.jacobian(t->D(x .+ t .* h2,h1), 0)[1]
D2(x,h,h) # does not work

I also have this code with ForwardDiff. I am not sure whether the problem is the same.

using ForwardDiffChainRules, ForwardDiff
@ForwardDiff_frule (f::typeof(implicit))(x::AbstractMatrix{<:ForwardDiff.Dual})
D(x,h) = ForwardDiff.derivative(t->implicit(x .+ t .* h), 0)
D(x,h) # works
D2(x,h1,h2) = ForwardDiff.derivative(t->D(x .+ t .* h2,h1), 0)
D2(x,h,h) # does not work

rveltz avatar Feb 15 '23 13:02 rveltz

what's the error?

mohdibntarek avatar Feb 15 '23 16:02 mohdibntarek

For Zygote:

julia> D2(x,h,h) # does not work
ERROR: Compiling Tuple{typeof(Optim.perform_linesearch!), Optim.LBFGSState{Matrix{Float64}, Vector{Matrix{Float64}}, Vector{Matrix{Float64}}, Float64, Matrix{Float64}}, LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Optim.var"#19#21"}, Optim.ManifoldObjective{OnceDifferentiable{Float64, Matrix{Float64}, Matrix{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

For ForwardDiff:


julia> D2(x,h,h) # does not work
ERROR: MethodError: no method matching gmres(::LinearOperators.LinearOperator{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}, Int64, ImplicitDifferentiation.var"#mul_A!#11"{ImplicitDifferentiation.var"#pushforward_A#9"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ForwardDiffChainRules.ForwardDiffRuleConfig, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Nothing, Nothing, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}})

It seems the second differential bypasses the machinery of ImplicitDifferentiation.

rveltz avatar Feb 15 '23 20:02 rveltz

Do you think this requires implementing the second differential in the implicit function theorem, or should the algorithm recursively call the first differential?

rveltz avatar Mar 01 '23 07:03 rveltz

> Do you think this requires implementing the second differential in the implicit function theorem, or should the algorithm recursively call the first differential?

I've been asking myself the same question for a project of mine. The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.

gdalle avatar Mar 01 '23 08:03 gdalle

> It seems the second differential bypasses the machinery of ImplicitDifferentiation.

That's definitely worth investigating. Maybe we should define a rule on the pushforward / pullback?

gdalle avatar Mar 01 '23 08:03 gdalle

I am adding the link to the Discourse post.

rveltz avatar Mar 01 '23 08:03 rveltz

> The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.

An example is here https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/custom.jl

mohdibntarek avatar Mar 01 '23 09:03 mohdibntarek

Changing this line https://github.com/gdalle/ImplicitDifferentiation.jl/blob/1581d2e3b1b1ddf083f0d370c9f7b323aa98f610/src/implicit_function.jl#L64 to

y = implicit(x; kwargs...)

and using a linear solver that's compatible with ForwardDiff makes the ForwardDiff case work:

linear_solver(A, x) = (Matrix(A) \ x, (solved = true,))
implicit = ImplicitFunction(dumb_identity, zero_gradient, linear_solver);

Zygote still fails but I suspect that's because Zygote.jacobian mutates and is tripping Zygote's second differentiation. I have run into this issue in the past and partially avoided it by defining a version of jacobian that's not mutating. Perhaps using AbstractDifferentiation's implementation would be better here. Anyways, I will leave the rest of it to you :)

mohdibntarek avatar Mar 01 '23 14:03 mohdibntarek
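
To illustrate why a dense `\` solve counts as "compatible with ForwardDiff" where `gmres` was not: Julia's generic LU factorization accepts matrices of `ForwardDiff.Dual`, so first and nested second derivatives propagate through the solve. The following is a minimal standalone sketch, not code from ImplicitDifferentiation.jl; the matrices `A` and `B`, the vector `b`, and the function `solve_at` are made up for illustration.

```julia
using ForwardDiff
using LinearAlgebra

# Illustrative data (assumed, not taken from the issue).
A = [4.0 1.0; 1.0 3.0]
B = [1.0 0.0; 0.0 2.0]
b = [1.0, 2.0]

# x(t) solves (A + t*B) x = b with a dense direct solve.
# When t is a Dual number, the matrix has Dual entries and the
# generic LU code path is used.
solve_at(t) = (A + t * B) \ b

# First derivative, then a nested second derivative, both at t = 0.
d1 = ForwardDiff.derivative(solve_at, 0.0)
d2 = ForwardDiff.derivative(s -> ForwardDiff.derivative(solve_at, s), 0.0)

# Closed forms at t = 0 for comparison:
#   x'  = -A^{-1} B A^{-1} b
#   x'' =  2 A^{-1} B A^{-1} B A^{-1} b
x0 = A \ b
d1_exact = -(A \ (B * x0))
d2_exact = 2 * (A \ (B * (A \ (B * x0))))
```

Both derivatives are obtained by differentiating through the factorization itself, which is the cost discussed later in the thread.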

> Changing this line

Seems easy enough to fix in the source code, should we?

gdalle avatar Mar 01 '23 16:03 gdalle

> Seems easy enough to fix in the source code, should we?

Let's

mohdibntarek avatar Mar 01 '23 16:03 mohdibntarek

Random thought: the fact that we need a dual-compatible linear solver means we still autodiff through the iterations of the solver. Is it possible to avoid that altogether for second-order?

gdalle avatar Mar 10 '23 10:03 gdalle

In theory we could. That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that. We would need to think of a good way to make this work for direct and iterative linear solvers.

mohdibntarek avatar Mar 10 '23 10:03 mohdibntarek
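
A rough sketch of what "the linear solver as an implicit function" could mean: the solution of `A x = b` satisfies the condition `A x - b = 0`, and differentiating that condition gives `A dx = db - dA x`. The tangent then costs one more solve with the same plain floating-point factorization instead of running the solver on dual numbers. The helper `solve_with_tangent` and the test data below are hypothetical and are not part of ImplicitDifferentiation.jl or ForwardDiffChainRules.

```julia
using LinearAlgebra

# Hypothetical helper: solve A x = b and propagate input tangents (dA, db)
# to the output tangent dx by differentiating the condition A x - b = 0.
function solve_with_tangent(A, b, dA, db)
    F = lu(A)                 # factorize once, in plain floating point
    x = F \ b                 # primal solution
    dx = F \ (db - dA * x)    # tangent: A dx = db - dA x, same factorization
    return x, dx
end

# Illustrative data (assumed).
A  = [4.0 1.0; 1.0 3.0]
dA = [1.0 0.0; 0.0 2.0]
b  = [1.0, 2.0]
db = [0.5, -1.0]

x, dx = solve_with_tangent(A, b, dA, db)

# Central finite difference along the direction (dA, db) as a sanity check.
ε = 1e-6
dx_fd = ((A + ε * dA) \ (b + ε * db) - (A - ε * dA) \ (b - ε * db)) / (2ε)
```

For an iterative solver such as GMRES the same identity would apply, with the second factorization solve replaced by a second iterative solve; how to expose that generically is the open design question raised in the comment above.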

Do you know if the Python package has higher-order implicit derivatives? If not, this could be an interesting conference paper.

mohdibntarek avatar Mar 10 '23 11:03 mohdibntarek

> Do you know if the Python package has higher-order implicit derivatives?

No, it only focuses on first-order differentiation. But it outsources the actual autodiff to JAX, which may be better at second-order stuff than Zygote, for example.

gdalle avatar Mar 10 '23 15:03 gdalle

> That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that.

My thought was actually to differentiate the implicit function theorem a second time, like this.

gdalle avatar Mar 10 '23 15:03 gdalle
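
For readers without the linked derivation, differentiating the implicit function theorem a second time can be sketched as follows. The notation is an assumption made here, not taken from the link: $c(x, y)$ denotes the conditions (the role played by `zero_gradient` in the example at the top), $y(x)$ the implicit function, $h_1, h_2$ the input directions, and $\partial_{xy} c\,[h, u]$ the mixed second derivative applied to an $x$-direction $h$ and a $y$-direction $u$.

```latex
\begin{aligned}
&\text{Conditions: } c(x, y(x)) = 0, \qquad A = \partial_y c, \quad B = \partial_x c. \\
&\text{First order, with } u_i = Dy(x)[h_i]: \qquad A\,u_i = -B\,h_i. \\
&\text{Second order, with } w = D^2 y(x)[h_1, h_2]: \\
&\qquad A\,w = -\Big( \partial_{xx} c\,[h_1, h_2] + \partial_{xy} c\,[h_1, u_2]
      + \partial_{yx} c\,[u_1, h_2] + \partial_{yy} c\,[u_1, u_2] \Big).
\end{aligned}
```

Both orders involve a linear system with the same operator $A$; only the right-hand side changes. For the example at the top, $c(x, y) = 2(y - x)$ has vanishing second derivatives, so $w = 0$, which matches the fact that the implicit function there is the identity.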

I believe what the derivations in this link are doing is going to be computationally identical to making our linear solver an implicit function and letting ImplicitDifferentiation automate the rest for us. They just defined the higher order rule manually but then it gets ugly with multiple inputs and outputs, Hessian rules for functions, etc. We have a nice abstraction which can be nested, let's take advantage of that.

mohdibntarek avatar Mar 10 '23 17:03 mohdibntarek

> what the derivations in this link are doing is going to be computationally identical to making our linear solver an implicit function

I believe you're right. Although I think the SciML solvers from https://github.com/SciML/LinearSolve.jl are already differentiable with similar implicit machinery, it can't hurt to roll our own simpler version.

gdalle avatar Mar 10 '23 18:03 gdalle

I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough. Maybe we can revisit it or roll our own, shouldn't be too hard for simple solvers.

mohdibntarek avatar Mar 10 '23 19:03 mohdibntarek

I don't think we need that many solvers so direct and GMRES should be enough.

mohdibntarek avatar Mar 10 '23 19:03 mohdibntarek

> I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough

It should definitely be more mature now but my main beef is that it has a truckload of dependencies, which cannot be made conditional until Julia 1.9

https://github.com/SciML/LinearSolve.jl/blob/main/Project.toml

gdalle avatar Mar 10 '23 19:03 gdalle

Now (in main), you can pass DirectLinearSolver() as the linear_solver in the implicit function to do second-order differentiation. For the more general case of making the linear solver an implicit function, I think we should open another issue.

mohdibntarek avatar Jul 30 '23 13:07 mohdibntarek

See #77

gdalle avatar Jul 30 '23 13:07 gdalle

Closing in favor of #77

gdalle avatar Feb 21 '24 13:02 gdalle