NLsolve.jl

Reverse differentiation through nlsolve

Open antoine-levitt opened this issue 6 years ago • 52 comments

OK, so this is pretty speculative as the reverse differentiation packages are not there yet, but let's dream for a moment. It would be awesome to be able to just use reverse-mode differentiation on code like

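using NLsolve

function G(α)
    F(x) = x - α                            # example residual; its root is x = α
    x = nlsolve(F, zeros(length(α))).zero   # nlsolve returns a result object; .zero holds the root
    H(x) = sum(x)                           # example scalar output
    H(x)
end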

and take the gradient of G wrt α. Of course, both F and H are examples, and can be arbitrary functions.

So how do we get the gradient of G? One can of course forward-diff through G, which should not be too hard to support from the nlsolve side (although I haven't tried), but that is pretty inefficient if α is high-dimensional. One can also reverse-diff through G, but that is heavy since it basically has to record all the iterations. A better idea is to exploit the mathematical structure of the problem, in particular the relationship dx/dα = -(∂F/∂x)^-1 ∂F/∂α (obtained by differentiating F(x(α),α) = 0 with respect to α), assuming nlsolve has converged exactly. Reverse-mode autodiff requires computing (dx/dα)^T δx, which is -(∂F/∂α)^T (∂F/∂x)^-T δx. If the Jacobian is not provided (Broyden or Anderson), the solve with (∂F/∂x)^T can be done with an iterative solver such as GMRES, where each matvec with (∂F/∂x)^T is performed with reverse diff.
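A minimal sketch of the adjoint formula above, using only the LinearAlgebra standard library. A hand-written Newton iteration stands in for nlsolve, and the residual F, the matrix A, and the helper names (solve_x, grad_adjoint, grad_fd) are hypothetical choices for illustration. The gradient comes from a single transposed linear solve at the converged root, with no iteration history, and is checked against central finite differences.

```julia
using LinearAlgebra

# Hypothetical test problem: F(x, α) = A*x + 0.1*x.^3 - α
# so ∂F/∂x = A + Diagonal(0.3*x.^2) and ∂F/∂α = -I.
A = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 5.0]
F(x, α) = A * x + 0.1 * x .^ 3 - α
dFdx(x) = A + Diagonal(0.3 * x .^ 2)

# Plain Newton iteration standing in for nlsolve; only the final root is returned.
function solve_x(α; tol = 1e-13, maxiter = 50)
    x = zeros(length(α))
    for _ in 1:maxiter
        r = F(x, α)
        norm(r) < tol && break
        x -= dFdx(x) \ r
    end
    return x
end

H(x) = sum(x)
G(α) = H(solve_x(α))

# Adjoint gradient: dG/dα = -(∂F/∂α)^T (∂F/∂x)^{-T} ∇H, with ∇H = ones for H = sum.
function grad_adjoint(α)
    x = solve_x(α)
    λ = dFdx(x)' \ ones(length(x))           # (∂F/∂x)^{-T} ∇H
    dFdα = -Matrix(1.0I, length(α), length(α))
    return -dFdα' * λ                         # -(∂F/∂α)^T λ
end

# Central finite differences, used only as a check.
function grad_fd(α; h = 1e-6)
    n = length(α)
    g = zeros(n)
    for i in 1:n
        e = zeros(n)
        e[i] = h
        g[i] = (G(α + e) - G(α - e)) / (2h)
    end
    return g
end

α = [1.0, -2.0, 0.5]
g_adj = grad_adjoint(α)
g_fd = grad_fd(α)
```

For large problems, the direct solve `dFdx(x)' \ ...` is the step that would be replaced by GMRES with reverse-mode matvecs, as described above.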

The action point for nlsolve here is to write a reverse ChainRule (https://github.com/JuliaDiff/ChainRules.jl) for nlsolve. This might be tricky because nlsolve takes a function as argument, but we might get by with just calling a diff function on F recursively. CC @jrevels to check this isn't a completely stupid idea. Of course, this isn't necessarily specific to nlsolve; the same ideas apply to optim (writing ∇F = 0) and diffeq (adjoint equations) for instance.

antoine-levitt, Jan 05 '19 10:01

You might want to take a look at https://arxiv.org/abs/1812.01892 . Yes, forward is fast but doesn't scale, and hard-coded adjoints do well. But I think the golden solution might be to just wait for source-to-source like Zygote.jl since then reverse mode can be done without operation tracking.

ChrisRackauckas, Jan 08 '19 20:01

Even then, Zygote still has to record all the history of the iterative method and then run it backwards, so that'll likely be slow and memory-consuming, won't it?

antoine-levitt, Jan 08 '19 20:01

I don't think it has to build a tape to handle loops? But then again, I don't know how it would know how many times to go back through the loop. @MikeInnes

ChrisRackauckas, Jan 08 '19 20:01

Well if you backward differentiate through a loop don't you have to keep track of all the intermediary steps?

I took a look at your (very nice, btw) paper. I think one key difference between adjoints for solving equations and for differential equations is that with diff eqs you are interested in the whole solution, so you have no choice but to keep it in memory, at which point reverse-mode differentiating through the solver doesn't look so bad. When solving equations you only want the final solution and discard the iterates. This enables more efficient adjoints for nonlinear solves, where the convergence history can simply be thrown away.

To make my point above clearer, take a simpler case: computing the gradient of y -> <b, A^-1 y>, where the A solve is done through a simple iterative method like CG (let's say). If you run reverse AD through CG, you're going to need to store all iterates (or use fancy techniques), and need potentially a lot of memory. Instead, you can just write the gradient as A^-T b and CG-solve that. Obviously this is a trivial example but it generalizes to arbitrary nonlinear systems and outputs.
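A small sketch of this point, assuming A is symmetric positive definite so that plain CG applies (for a nonsymmetric A one would use GMRES instead). The `cg` function below is a textbook conjugate gradient written for this example, not a library routine, and the random test data is hypothetical. The gradient A^-T b is obtained with one extra CG solve rather than by storing and replaying the CG iterates.

```julia
using LinearAlgebra

# Textbook conjugate gradient for a symmetric positive definite operator `op`.
# Only the current iterate and two work vectors are kept; no history is stored.
function cg(op, rhs; tol = 1e-12, maxiter = 10 * length(rhs))
    x = zeros(length(rhs))
    r = copy(rhs)
    p = copy(r)
    rr = dot(r, r)
    for _ in 1:maxiter
        sqrt(rr) < tol * norm(rhs) && break
        Ap = op(p)
        a = rr / dot(p, Ap)
        x .+= a .* p
        r .-= a .* Ap
        rr_new = dot(r, r)
        p .= r .+ (rr_new / rr) .* p
        rr = rr_new
    end
    return x
end

n = 50
M = randn(n, n)
A = M' * M + n * I   # hypothetical SPD test matrix
b = randn(n)

# Forward map y ↦ <b, A⁻¹ y>, with the solve done by CG.
objective(y) = dot(b, cg(v -> A * v, y))

# Gradient A⁻ᵀ b: one more CG solve, this time with Aᵀ (equal to A here).
grad = cg(v -> A' * v, b)
```

Because the objective is linear in y, `objective(y)` should equal `dot(grad, y)` for any y, which gives a cheap sanity check.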

antoine-levitt, Jan 08 '19 20:01

I see what you're saying and it totally makes sense in this case. I don't know the right place to overload for this though. We did it directly on Flux.Tracker types.

ChrisRackauckas, Jan 08 '19 20:01

This sounds very useful. How about starting from something very simple, like this?

using NLsolve
using Zygote
using Zygote: @adjoint, forward

@adjoint nlsolve(f, j, x0; kwargs...) =
    let result = nlsolve(f, j, x0; kwargs...)
        result, function(vresult)
            # This backpropagator returns (- v' (df/dx)⁻¹ (df/dp))'
            v = vresult[].zero
            x = result.zero
            J = j(x)
            _, back = forward(f -> f(x), f)
            return (back(-(J' \ v))[1], nothing, nothing)
        end
    end

It looks like it's working:

julia> d, = gradient(p -> nlsolve(x -> [x[1]^3 - p],
                                  x -> fill(3x[1]^2, (1, 1)),
                                  [1.0]).zero[1],
                     8.0)
(0.08333333333333333,)

julia> d ≈ 1/3 * 8.0^(1/3 - 1)
true
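The value 0.0833 can be checked by hand with the implicit-function formula from the opening post. In the adjoint above the seed is v = 1 and J = 12, so `back(-(J' \ v))` receives -1/12, and since f contains -p the pullback with respect to p returns +1/12:

```latex
\begin{aligned}
f(x, p) &= x^3 - p, \qquad p = 8 \;\Rightarrow\; x = 2,\\
\frac{\partial f}{\partial x} &= 3x^2 = 12, \qquad \frac{\partial f}{\partial p} = -1,\\
\frac{dx}{dp} &= -\Big(\frac{\partial f}{\partial x}\Big)^{-1}\frac{\partial f}{\partial p} = \frac{1}{12} \approx 0.0833,\\
\frac{d}{dp}\, p^{1/3}\Big|_{p=8} &= \tfrac{1}{3}\, 8^{-2/3} = \tfrac{1}{3}\cdot\tfrac{1}{4} = \tfrac{1}{12}.
\end{aligned}
```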

tkf, Aug 26 '19 08:08

Is Zygote ready to be used in the wild?

pkofod, Aug 26 '19 08:08

Wow, that is so cool. Now we just need https://github.com/JuliaDiff/ChainRulesCore.jl/issues/22 to be able to take a dependency on ChainRulesCore, put that code in nlsolve, add the iterative solve for the zeroth-order methods, and we rule the world! (well, except for mutation...)

antoine-levitt, Aug 26 '19 08:08

(well, except for mutation...)

What do you mean?

pkofod, Aug 26 '19 08:08

https://github.com/FluxML/Zygote.jl/pull/75

antoine-levitt, Aug 26 '19 08:08

@antoine-levitt FYI it looks like complex number interface could become another blocker for Zygote users (see https://github.com/FluxML/Zygote.jl/issues/142#issuecomment-525209133) although I guess we still can use other AD packages based on ChainRulesCore? But are there other AD packages closer to production-ready than Zygote.jl? I'm wondering if it makes sense to define just for Zygote.jl for now. Reading https://github.com/FluxML/Zygote.jl/pull/291 it seems that ChainRulesCore's API would be close to Zygote.jl so migration doesn't sound hard.

Of course, people can just define their own wrapper. I did it already (https://github.com/tkf/SteadyStateFit.jl/blob/239b18252ea5a596b780ddfbeb483b7e80b17572/src/znlsolve.jl) so this is not a blocker for me personally anymore.

tkf, Aug 28 '19 01:08

Don't think it's such a blocker: both zygote and chainrules support complex differentials in their full generality, it's just a question of putting the APIs together and of optimization.

It looks like zygote and chainrules are going to mesh in the short term, so we might as well wait until then. @oxinabox, does that sound reasonable? I think it's better for nlsolve to take on a dependency on ChainRulesCore than on Zygote.

antoine-levitt, Aug 28 '19 06:08

Sounds reasonable to me

pkofod, Aug 28 '19 06:08

Don't think it's such a blocker: both zygote and chainrules support complex differentials in their full generality, it's just a question of putting the APIs together and of optimization.

Yeah, I don't think it will be much of a blocker.

It looks like zygote and chainrules are going to mesh in the short term, so we might as well wait until then. @oxinabox, does that sound reasonable? I think it's better for nlsolve to take on a dependency on ChainRulesCore than on Zygote.

No later than end of the year. Hopefully much sooner.

There is also ZygoteRules.jl, which is, I think, a Zygote-specific equivalent of ChainRulesCore. One could use that in the short term, but I am hoping for it to either be retired or marked as "use only if you need rules that depend on the internals of Zygote". ChainRules is more general. Most notably, it also supports the upcoming ForwardDiff2.

oxinabox, Aug 28 '19 07:08

OK, that's good news.

tkf, Aug 28 '19 22:08

Lyndon mentioned it, but just linking ZygoteRules explicitly.

RE using Zygote in the wild: the marker for that is really going to be when we release Flux + Zygote; once it's ready for that it's going to have been pretty heavily user-tested. OTOH it's still a relatively large dependency. My suggestion would be to use ZygoteRules to add this adjoint for now, and then switch to ChainRules once it's ready.

MikeInnes, Aug 29 '19 08:08

So here's a (very crude) prototype of reverse diffing through a nonlinear PDE solve (based on @tkf's code, but with fully iterative methods, to get something representative of large-scale applications)

using NLsolve
using Zygote
using Zygote: @adjoint, forward
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra   # for Diagonal
using BenchmarkTools  # for @btime

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = forward(f, x)

            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L,-dx)

            _, back_f = forward(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)

@btime gradient(obj, p0)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))

Performance is not great: essentially 20x slower than the analytic version. However, profiling shows that this overhead is pretty localized, so it might be possible to optimize it away and get essentially the same performance as the analytic version (this should be a relatively easy case for reverse diff, since there are only vector operations and no loop). I'm not quite sure what's going on here; one possibility is that Zygote tries to differentiate with respect to globally defined constants.
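To make the matrix-free idea concrete without Zygote or IterativeSolvers, here is a sketch with stdlib-only stand-ins. `JT` is a hand-written vector-Jacobian product playing the role of `back_x`, and `gmres_matfree` is a bare-bones unrestarted GMRES written for this example (it is not the IterativeSolvers.jl implementation). The model mirrors the f(x, p) above but with a smaller, dense A and a Newton solve instead of Anderson; these are all assumptions made to keep the block self-contained.

```julia
using LinearAlgebra

# Minimal unrestarted GMRES that only needs a function op(v) returning a matrix-vector product.
function gmres_matfree(op, b; tol = 1e-10, maxiter = length(b))
    n = length(b)
    β = norm(b)
    Q = zeros(n, maxiter + 1)          # Arnoldi basis
    Hm = zeros(maxiter + 1, maxiter)   # upper Hessenberg matrix
    Q[:, 1] = b / β
    y = zeros(0)
    k = 0
    for j in 1:maxiter
        k = j
        w = op(Q[:, j])
        for i in 1:j                   # modified Gram-Schmidt
            Hm[i, j] = dot(Q[:, i], w)
            w -= Hm[i, j] * Q[:, i]
        end
        Hm[j + 1, j] = norm(w)
        rhs = zeros(j + 1)
        rhs[1] = β
        y = Hm[1:j + 1, 1:j] \ rhs      # small least-squares problem
        if norm(Hm[1:j + 1, 1:j] * y - rhs) <= tol * β || Hm[j + 1, j] == 0
            break
        end
        Q[:, j + 1] = w / Hm[j + 1, j]
    end
    return Q[:, 1:k] * y
end

N = 200
nonlin = 0.1
A = Matrix(SymTridiagonal(fill(10.0, N), fill(-1.0, N - 1)))  # dense stand-in for the spdiagm matrix
p0 = randn(N)
f(x, p) = A * x + nonlin * x .^ 2 - p
jac(x) = A + Diagonal(2 * nonlin * x)

# Newton iteration standing in for nlsolve(..., method=:anderson).
function newton(p; tol = 1e-12)
    x = zeros(length(p))
    for _ in 1:50
        r = f(x, p)
        norm(r) < tol && break
        x -= jac(x) \ r
    end
    return x
end

x = newton(p0)
# Hand-written VJP v ↦ (∂f/∂x)ᵀ v, standing in for a reverse-mode pullback.
JT(v) = A' * v + 2 * nonlin * x .* v
# Gradient of sum(x(p)): since ∂f/∂p = -I, g = (∂f/∂x)^{-T} ones.
g_matfree = gmres_matfree(JT, ones(N))
g_direct = jac(x)' \ ones(N)
```

Only `JT` touches the model, so swapping it for an AD pullback leaves the solver untouched; GMRES is used rather than CG because (∂f/∂x)ᵀ need not be symmetric in general.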

antoine-levitt, Aug 31 '19 11:08

You could try explicitly dropping gradients of globals to see if that's the issue.

MikeInnes, Sep 02 '19 13:09

OK, but how do I do that?

antoine-levitt, Sep 02 '19 17:09

I see a similar hiccup with closures and would like to know any solution/workaround https://github.com/FluxML/Zygote.jl/issues/323

tkf, Sep 03 '19 00:09

@tkf I see you closed the issue there, but the discussion there was too technical for me to follow. Could you summarize what it means for the code above? Will it be fixed by the chainrules integration?

antoine-levitt, Sep 04 '19 15:09

@antoine-levitt Short answer is, IIUC, it'll be solved by switching to ChainRulesCore. I posted a longer answer with step-by-step code in https://github.com/FluxML/Zygote.jl/issues/323#issuecomment-528125263 explaining why I thought it was solved.

tkf, Sep 04 '19 23:09

If I use LinearAlgebra.diagm in f(x,p), it raises "Need an adjoint for constructor Pair". How can I write an adjoint method similar to the one above? Many thanks!

Yansf677, Aug 10 '20 12:08

I don't know, you probably need to take it up with Zygote (or ChainRules). Also note that the above code was written for an older version of Zygote and needs updating (if anyone does so, please post the result and check whether the above-mentioned slowdown is still present!)

antoine-levitt, Aug 13 '20 08:08

Here's @antoine-levitt's example as of today

using NLsolve
using Zygote
using Zygote: @adjoint
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = Zygote.pullback(f, x)

            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L,-dx)

            _, back_f = Zygote.pullback(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs.(g_auto - g_analytic))

@btime gradient(obj, p0); 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));

My local timings:

@btime gradient(obj, p0);  # 2.823 s (1141 allocations: 5.99 GiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 21.230 ms (908 allocations: 24.03 MiB)
Status `~/.../Project.toml`
   [6e4b80f9] BenchmarkTools v0.5.0
   [42fd0dbc] IterativeSolvers v0.9.0
   [7a12625a] LinearMaps v3.2.0
   [2774e3e8] NLsolve v4.5.1
   [e88e6eb3] Zygote v0.6.3
   [2f01184e] SparseArrays

niklasschmitz, Feb 04 '21 19:02

Thanks @niklasschmitz for providing the updated code. The runtime for gradient is now more than 100x that of gmres (2.823 s vs 21.230 ms), whereas @antoine-levitt reported 20x. Do you know what happened here? Also, the number of allocations for Zygote is about the same as for GMRES (1141 vs 908), but the total memory allocated is about 250x larger (5.99 GiB vs 24.03 MiB). What is causing these large allocations?

rkube, Mar 11 '21 15:03

It now seems to me that the big slowdown is caused by the sparse matrix A somehow not being handled efficiently. Here's what I get for a dense matrix A (and choosing N=1000) by changing the above snippet:

- const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
+ const A = Array(spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))) # try dense A, for comparison only

For this I get the following timings:

@btime gradient(obj, p0); #   26.382 ms (624 allocations: 63.30 MiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  #   16.002 ms (446 allocations: 9.52 MiB)

So the previous 100x relative slowdown seems gone, cc @antoine-levitt @rkube

The large performance penalties when using SparseMatrixCSC for A might be worth its own Zygote issue, if it doesn't exist already.

niklasschmitz, Mar 25 '21 17:03

Could it also be that the cost of matvecs with dense matrices is much larger than with sparse ones, so that it hides the overhead?

antoine-levitt, Mar 25 '21 18:03

To double-check, I tried the sparse case again, but with a custom rrule for the inner function f:

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(f), x, p)
    y = f(x, p)
    function f_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NO_FIELDS, ∂x, ∂p)
    end
    return y, f_pullback
end
Zygote.refresh()
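One way to convince oneself that the hand-written rrule is right is a dot-product test: the pullback with respect to x should be (∂f/∂x)ᵀ ȳ, with ∂f/∂x = A + Diagonal(2nonlin*x). The sketch below checks this on a small dense version of the same f, without ChainRulesCore; `pullback_x`, `pullback_p`, and the sizes are illustrative assumptions, and `ybar` stands in for ȳ.

```julia
using LinearAlgebra

# Same model as above, but small and dense so the Jacobian can be formed explicitly.
N = 50
nonlin = 0.1
A = Matrix(SymTridiagonal(fill(10.0, N), fill(-1.0, N - 1)))
f(x, p) = A * x + nonlin * x .^ 2 - p

# Pullback formulas from the rrule above, written as plain functions of the cotangent ybar.
pullback_x(x, ybar) = A' * ybar + 2nonlin * x .* ybar
pullback_p(ybar) = -ybar

x = randn(N)
p = randn(N)
ybar = randn(N)
d = randn(N)   # arbitrary direction for the checks

# Explicit Jacobians: ∂f/∂x = A + Diagonal(2nonlin*x), ∂f/∂p = -I.
Jx = A + Diagonal(2nonlin * x)
Jp = -Matrix(1.0I, N, N)

# Dot-product test: <ybar, (∂f/∂x) d> should equal <pullback_x(x, ybar), d>.
lhs = dot(ybar, Jx * d)
rhs = dot(pullback_x(x, ybar), d)

# Central difference of f along d; exact up to rounding because f is quadratic in x.
h = 1e-3
fd = (f(x + h * d, p) - f(x - h * d, p)) / (2h)
```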

Going back to the original example problem from above (i.e. N=10000 and A=spdiagm(...)) I now get

@btime gradient(obj, p0);  # 22.756 ms (986 allocations: 23.99 MiB) 
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 23.065 ms (786 allocations: 21.23 MiB)
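For reference, the slowdown factors discussed in this thread can be recomputed from the reported numbers (N = 10000 except the dense run, which used N = 1000). Note that both timed expressions include the nonlinear solve, since `solve_x(p0)` is evaluated inside the `gmres` call. The layout below is just an illustrative way to tabulate them.

```julia
# (label, gradient time [s], gradient memory [MiB], gmres time [s], gmres memory [MiB]),
# taken from the @btime outputs reported above.
runs = [
    ("sparse A, Zygote pullback of f", 2.823, 5.99 * 1024, 21.230e-3, 24.03),
    ("dense A (N = 1000)", 26.382e-3, 63.30, 16.002e-3, 9.52),
    ("sparse A, hand-written rrule for f", 22.756e-3, 23.99, 23.065e-3, 21.23),
]

# Ratio of the gradient call to the reference gmres expression, for time and memory.
ratios = [(name, tg / tr, mg / mr) for (name, tg, mg, tr, mr) in runs]

for (name, rt, rm) in ratios
    println(rpad(name, 38), "time x", round(rt; digits = 2), "   memory x", round(rm; digits = 2))
end
```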

niklasschmitz, Mar 26 '21 10:03

https://github.com/JuliaDiff/ChainRulesCore.jl/pull/363 is required to avoid the Zygote dependency.

ChrisRackauckas, Jun 09 '21 12:06