Zygote.jl
Calculating 2nd order differentials of a function containing a NN w.r.t. parameters
Do you support any way to calculate the 2nd order differential of a function containing a neural network, w.r.t. the parameters?
For instance, consider a loss function defined as below.
using Flux: Chain, Dense, σ, crossentropy, params
using Zygote
model = Chain(
    x -> reshape(x, :, size(x, 4)),
    Dense(2, 5),
    Dense(5, 1),
    x -> σ.(x)
)
n_data = 5
input = randn(2, 1, 1, n_data)
target = randn(1, n_data)
loss = model -> crossentropy(model(input), target)
My goal is to obtain the hessian of loss(model) w.r.t. Flux.params(model). Ideally I would like to have a function like hessian(loss, model), which is currently not supported in Zygote.
In order to construct this, I tried some approaches combining Zygote.gradient and Zygote.jacobian (added in #890).
- Of course, simply combining the two did not work:
zygrad = model -> Zygote.gradient(loss, model)
zyjacob = model -> Zygote.jacobian(zygrad, model)
zyjacob(model)    # ERROR: ArgumentError: jacobian expected a function which returns an array, or a scalar, got Tuple{NamedTuple{(:layers,),Tuple{Tuple{Nothing,NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},Nothing}}}}
# or more explicitly, obtaining a jacobian of one of the weight matrices
zygrad_w1 = model -> Zygote.gradient(loss, model)[1][1][2][1]
zyjacob_w1 = model -> Zygote.jacobian(zygrad_w1, model)
zyjacob_w1(model)    # ERROR: Can't differentiate foreigncall expression
- ...or combining with an implicit gradient:
zygrad_implicit = θ -> Zygote.gradient(() -> loss(model), θ)[θ[1]]
zyjacob = θ -> Zygote.jacobian(() -> zygrad_implicit(θ), θ)
zyjacob(params(model))    # ERROR: Can't differentiate foreigncall expression
# or a different combination...
zyjacob = θ -> Zygote.jacobian(zygrad_implicit, θ)
zyjacob(params(model))    # ERROR: MethodError: no method matching (::var"#75#77")()
- So I thought I would build my own Jacobian from element-wise 2nd order gradients, but that didn't work either:
zygrad_p1 = model -> Zygote.gradient(loss, model)[1][1][2][1][1]    # First element
zygrad2 = model -> Zygote.gradient(zygrad_p1, model)
I'm still learning the framework and will continue to do so, but as of now I've been working on this for days and have exhausted my ideas for a workaround.
Calculating 2nd order differentials is essential in my field of work, and I would like to know whether Zygote will support it in the near future and whether there are any known solutions to this problem.
For reference, until now I've been using a hessian_wrt_all_params(func, model) function implemented in PyTorch.
I just wanted to implement it in Julia.
I don't think there's a way to do this with Zygote's implicit parameter dictionary, params(model). Or at least, nobody has written one. But if I'm reading correctly, your Python code steps through the parameters but builds up the entire Hessian matrix, which isn't something it makes sense to store per-parameter, since it has many off-diagonal blocks. If you need this whole matrix, you can get it by doing something like this:
v, re = Flux.destructure(model)        # length(v) == sum(length, params(model))
loss(model, x, y) = sum(abs2.(model(x) .- y))    # here x, y play the role of your input and target
g = Zygote.gradient(v -> loss(re(v), x, y), v)[1]  # length(g) == length(v)
h = Zygote.hessian(v -> loss(re(v), x, y), v)      # size(h) == (length(v), length(v))
Ref https://github.com/FluxML/Zygote.jl/pull/823
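For what it's worth, the destructure approach can be wrapped into the kind of `hessian(loss, model)` helper asked about above. This is only a sketch: the name `hessian_wrt_params` is made up here, and it materializes the full dense Hessian, so it is only practical for small models.

```julia
using Flux, Zygote

# Hypothetical helper: flatten all trainable parameters, then take the
# Hessian of the loss with respect to the flat parameter vector.
function hessian_wrt_params(loss, model)
    v, re = Flux.destructure(model)      # v: flat parameter vector, re: rebuilds the model
    Zygote.hessian(v -> loss(re(v)), v)  # dense (length(v), length(v)) matrix
end

# Tiny usage example on a one-layer model (2 weights + 1 bias = 3 parameters):
model = Dense(2, 1)
x = randn(Float32, 2, 3)
y = randn(Float32, 1, 3)
loss(m) = sum(abs2, m(x) .- y)
H = hessian_wrt_params(loss, model)
size(H)   # (3, 3)
```

Since the loss above is quadratic in the parameters, this particular Hessian is constant in `v`; for a deeper network with nonlinearities it will of course depend on the current parameters.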
@mcabbott Great! The entire Hessian matrix is exactly what I needed in the end.
The Python code looks messy because every gradient method I could find computed per-layer results (giving blocks), and whenever I tried flattening the parameters first, it would break. Once I got it to work, I didn't dare touch it.
I never came across Flux.destructure() during my investigation. This changes everything. I will try it. Thank you! :)