Zygote.jl

Calculating 2nd-order differentials of a function containing a NN w.r.t. its parameters

Open · stash-196 opened this issue 4 years ago · 4 comments

Is there any supported way to calculate the second-order differential (the Hessian) of a function containing a neural network with respect to its parameters?

For instance, consider a loss function defined as below.

```julia
using Flux: Chain, Dense, σ, crossentropy, params
using Zygote

model = Chain(
    x -> reshape(x, :, size(x, 4)),   # flatten the 2×1×1×N input to 2×N
    Dense(2, 5),
    Dense(5, 1),
    x -> σ.(x)                        # sigmoid output in (0, 1)
)
n_data = 5
input = randn(2, 1, 1, n_data)
target = rand(1, n_data)              # targets in [0, 1], as cross-entropy expects
loss = model -> crossentropy(model(input), target)
```

My goal is to obtain the Hessian of `loss(model)` w.r.t. `Flux.params(model)`. Ideally I would like a function like `hessian(loss, model)`, which Zygote does not currently support.

To construct this, I tried several approaches combining `Zygote.gradient` and `Zygote.jacobian` (added in #890):

1. Of course, simply combining the two did not work:

```julia
zygrad = model -> Zygote.gradient(loss, model)
zyjacob = model -> Zygote.jacobian(zygrad, model)
zyjacob(model)    # ERROR: ArgumentError: jacobian expected a function which returns an array, or a scalar, got Tuple{NamedTuple{(:layers,),Tuple{Tuple{Nothing,NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},Nothing}}}}

# or, more explicitly, taking the Jacobian of the gradient of one weight matrix
zygrad_w1 = model -> Zygote.gradient(loss, model)[1][1][2][1]
zyjacob_w1 = model -> Zygote.jacobian(zygrad_w1, model)
zyjacob_w1(model)    # ERROR: Can't differentiate foreigncall expression
```
2. ...or combining them with implicit gradients:

```julia
zygrad_implicit = θ -> Zygote.gradient(() -> loss(model), θ)[θ[1]]
zyjacob = θ -> Zygote.jacobian(() -> zygrad_implicit(θ), θ)
zyjacob(params(model))    # ERROR: Can't differentiate foreigncall expression

# or a different combination...
zyjacob = θ -> Zygote.jacobian(zygrad_implicit, θ)
zyjacob(params(model))    # ERROR: MethodError: no method matching (::var"#75#77")()
```
3. So I thought I would build my own Jacobian from element-wise second-order gradients, but that didn't work either:

```julia
zygrad_p1 = model -> Zygote.gradient(loss, model)[1][1][2][1][1]    # first element
zygrad2 = model -> Zygote.gradient(zygrad_p1, model)
```
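The first attempt fails because `Zygote.gradient(loss, model)` returns a gradient with the same nested structure as the model (a `NamedTuple` holding one `NamedTuple` per layer, as the error message shows), while `Zygote.jacobian` only accepts a function returning an array or a scalar. A minimal sketch for inspecting that structure, on a small hypothetical model rather than the one above (`sqloss` is an illustrative name, and the `weight`/`bias` field names assume Flux ≥ 0.12; older versions used `W`/`b`):

```julia
using Flux, Zygote

# Hypothetical small model and input, for illustration only
model = Chain(Dense(2, 3), Dense(3, 1))
x = randn(Float32, 2, 5)
sqloss(m) = sum(abs2, m(x))

# Structural gradient: mirrors the Chain, not a flat array
g = Zygote.gradient(sqloss, model)[1]

g.layers                  # a Tuple with one NamedTuple per layer
g.layers[1].weight        # same shape as model.layers[1].weight
```

Because this is a nested container rather than an array, it cannot be passed directly to `Zygote.jacobian`; flattening the parameters first avoids that.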

I'm still learning the framework, but I've been working on this for days and have run out of ideas for a workaround.

Second-order differentials are essential in my field of work, so I would like to know whether Zygote will support them in the near future, and whether there are any known solutions to this problem.

stash-196 · Feb 25 '21 02:02

For reference, until now I've been using a `hessian_wrt_all_params(func, model)` that I implemented in PyTorch.

I just wanted to implement the same thing in Julia.

stash-196 · Feb 25 '21 05:02

I don't think there's a way to do this with Zygote's implicit parameter dictionary, params(model). Or at least, nobody has written one. But if I'm reading correctly, your Python code steps through the parameters but builds up the entire Hessian matrix, which isn't something it makes sense to store per-parameter, since it has many off-diagonal blocks. If you need this whole matrix, you can get it by doing something like this:

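```julia
using Flux, Zygote

v, re = Flux.destructure(model)        # length(v) == sum(length, params(model))
sqloss(model, x, y) = sum(abs2.(model(x) .- y))   # not named `loss`, which is already bound above
x, y = input, target                   # data from the question
g = Zygote.gradient(v -> sqloss(re(v), x, y), v)[1]  # length(g) == length(v)
h = Zygote.hessian(v -> sqloss(re(v), x, y), v)      # size(h) == (length(v), length(v))
```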
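Since the full matrix contains off-diagonal blocks coupling different parameter arrays, it can help to slice it back into per-array blocks. A minimal sketch, on a small hypothetical model (not the one from the question), assuming `Flux.destructure` concatenates each `Dense` layer's `weight` and then its `bias`, layer by layer; `sqloss`, `ranges`, and `hblock` are illustrative names, not Flux or Zygote API:

```julia
using Flux, Zygote

# Hypothetical small model and data, for illustration only
model = Chain(Dense(2, 3), Dense(3, 1))
x = randn(Float32, 2, 5)
y = randn(Float32, 1, 5)

v, re = Flux.destructure(model)          # flat parameter vector + rebuilder
sqloss(v) = sum(abs2, re(v)(x) .- y)     # loss as a function of the flat vector
h = Zygote.hessian(sqloss, v)            # full length(v) × length(v) Hessian

# Assumption: flattening order is layer by layer, weight before bias.
arrays = [p for layer in model.layers for p in (layer.weight, layer.bias)]
lens   = length.(arrays)
stops  = cumsum(lens)
ranges = [(s - l + 1):s for (s, l) in zip(stops, lens)]

# Block (i, j) holds ∂²L/∂θᵢ∂θⱼ for parameter arrays i and j;
# i ≠ j gives the off-diagonal blocks.
hblock(i, j) = h[ranges[i], ranges[j]]

size(hblock(1, 3))   # (6, 3): first-layer weight vs. second-layer weight
```

Keeping one flat matrix and slicing views out of it avoids storing the off-diagonal blocks per parameter.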

mcabbott · May 20 '21 02:05

Ref https://github.com/FluxML/Zygote.jl/pull/823

DhairyaLGandhi · May 20 '21 05:05

@mcabbott Great! The entire Hessian matrix is exactly what I needed in the end.

The Python code looks messy because every gradient method I could find computed gradients per layer (resulting in blocks), and whenever I tried flattening the parameters first, it broke. Once I got it working, I didn't dare touch it.

I never found Flux.destructure() during my investigation. This changes everything. I will try it. Thank you! :)

stash-196 · May 20 '21 05:05