Zygote.jl

Jacobian in loss function

Open · dfenn opened this issue 4 years ago • 3 comments

This seems similar to #820 and https://github.com/FluxML/Flux.jl/issues/1464, but the suggestions there don't seem to help in my case.

I'm trying to use the Jacobian of a specific layer's activations with respect to the network input in order to calculate a penalty in my loss function. I'm not sure if the problem is that my implementation is naive, or if what I'm trying to do just isn't supported. I'm using Flux 0.12.2 and Zygote 0.6.10 with Julia 1.6.1.
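To make the goal concrete, here is the penalty written out for the single Dense(numfeatures, numfeatures, sigmoid) layer used in the MWE below, for one input column x. This is a sketch for a single sample: the MWE itself passes a 4×2 batch, whose Jacobian is 8×8. The penalty shown (the sum of all Jacobian entries) mirrors the reg = sum(jac) idea in the MWE.

```latex
a(x) = \sigma(Wx + b), \qquad \sigma'(z) = \sigma(z)\,\bigl(1 - \sigma(z)\bigr)

J_{ij}(x) = \frac{\partial a_i}{\partial x_j} = \sigma'\bigl((Wx + b)_i\bigr)\, W_{ij}

R(W, b) = \sum_{i,j} J_{ij}(x) = \sum_i \sigma'\bigl((Wx + b)_i\bigr) \sum_j W_{ij}
```

R depends on W both directly and through σ', so training on it needs the derivative of the Jacobian itself with respect to the parameters, i.e. a second-order derivative.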

Running the following minimal working example (MWE):

using Flux
using ForwardDiff
using Zygote
using ReverseDiff

numfeatures = 4

model  = Chain(
        Dense(numfeatures, numfeatures, sigmoid),
        )

ps = Flux.params(model)
opt = ADAM(0.1)

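function loss(x)
    modelOutput = model(x)   # computed but unused; only the penalty is returned
    # Swap in Zygote.jacobian or ReverseDiff.jacobian here to reproduce the
    # other errors described below.
    jac = ForwardDiff.jacobian(model, x)[1]
    reg = sum(jac)
    return reg
end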

train_data = rand(numfeatures, 2)

for epoch in 1:10
    local trainLoss
    gs = Flux.gradient(ps) do
        trainLoss = loss(train_data)
        return trainLoss
    end
    Flux.Optimise.update!(opt, ps, gs)
    @show trainLoss
end
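One detail in the MWE is that the [1] after the jacobian call means different things for different packages. The snippet below (a toy function f, not from the thread) shows the difference: ForwardDiff.jacobian returns the Jacobian matrix itself, so [1] picks its first entry, while Zygote.jacobian returns a tuple with one Jacobian per argument, so [1] picks the whole matrix.

```julia
using ForwardDiff, Zygote

f(x) = x .^ 2                 # toy array-to-array function (illustrative only)
x = [1.0, 2.0, 3.0]

Jf = ForwardDiff.jacobian(f, x)   # a 3×3 Matrix{Float64}
Jz = Zygote.jacobian(f, x)        # a 1-tuple: one Jacobian per argument of f

first_entry = Jf[1]   # the (1,1) entry of the Jacobian, a scalar
full_matrix = Jz[1]   # the whole 3×3 Jacobian matrix
```

So when switching between the two calls in loss, the value passed to sum changes from a scalar to a matrix (or back), independently of the differentiation errors.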

The outcome depends on which version of jacobian I call. Calling the Zygote or ForwardDiff version results in "ERROR: LoadError: Mutating arrays is not supported". Calling the ReverseDiff version gives me "ERROR: LoadError: Can't differentiate foreigncall expression".
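For context on the first error: Zygote's reverse mode refuses to differentiate code that writes into arrays in place. Jacobian routines typically fill their output arrays this way, which is presumably what the outer Flux.gradient call trips over here. The sketch below uses two hypothetical functions (not from the thread) to reproduce the message in isolation and to show a non-mutating version of the same computation.

```julia
using Zygote

# Builds its result by writing into a preallocated array (setindex!),
# which Zygote's reverse mode does not support.
function mutating_sum(x)
    y = zeros(eltype(x), length(x))
    for i in eachindex(x)
        y[i] = 2 * x[i]
    end
    return sum(y)
end

# Same value, computed without any in-place writes.
nonmutating_sum(x) = sum(2 .* x)

x = rand(3)

msg = try
    Zygote.gradient(mutating_sum, x)
    "no error"
catch err
    sprint(showerror, err)   # expected to mention "Mutating arrays is not supported"
end

g = Zygote.gradient(nonmutating_sum, x)   # every entry of g[1] equals 2.0
```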

Should it be possible to get a Jacobian inside the loss function like this? If not, is there a better way to do it?

Thanks in advance; I appreciate any insight you can give me.

dfenn · Apr 26 '21 00:04

Zygote's jacobian function isn't itself Zygote-differentiable. There was an alternative sketched in https://github.com/FluxML/Zygote.jl/issues/865 which might be.

ForwardDiff's jacobian also won't directly be Zygote-differentiable. But you can use ForwardDiff twice, something like this:

julia> my_jacobian(x -> my_jacobian(x -> x.^3, x)[1], 1:3)[1]
3×9 Matrix{Float64}:
 6.0  0.0  0.0  0.0   0.0  0.0  0.0  0.0   0.0
 0.0  0.0  0.0  0.0  12.0  0.0  0.0  0.0   0.0
 0.0  0.0  0.0  0.0   0.0  0.0  0.0  0.0  18.0

julia> Zygote.jacobian(x -> Zygote.forwarddiff(z -> ForwardDiff.jacobian(x -> x.^3, z), x), 1:3)[1]
9×3 Matrix{Int64}:
 6   0   0
 0   0   0
 0   0   0
 0   0   0
 0  12   0
 0   0   0
 0   0   0
 0   0   0
 0   0  18

Something is slightly wrong there. Second derivatives are not well supported, but they can sometimes be made to work; you will have to experiment a bit and start small.
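The thread does not show how my_jacobian is defined. As a hypothetical stand-in, the sketch below wraps ForwardDiff.jacobian so that it returns a 1-tuple (mirroring Zygote.jacobian) and nests it, i.e. uses ForwardDiff twice. With this definition the result is 9×3, matching the layout of the Zygote output above rather than the 3×9 matrix printed for my_jacobian; that difference presumably comes from the helper's actual definition.

```julia
using ForwardDiff

# Hypothetical stand-in for my_jacobian: ForwardDiff.jacobian wrapped in a
# 1-tuple, like Zygote.jacobian's return value.
my_jacobian(f, x) = (ForwardDiff.jacobian(f, x),)

x = [1.0, 2.0, 3.0]

# Inner call: 3×3 Jacobian of z -> z.^3, i.e. Diagonal(3 .* x.^2).
# Outer call: Jacobian of that matrix-valued function. ForwardDiff treats the
# 3×3 output as a column-major vector of 9 entries, so the result is 9×3.
H = my_jacobian(y -> my_jacobian(z -> z .^ 3, y)[1], x)[1]
```

Only the diagonal entries of the inner Jacobian depend on x, so H has three nonzeros: 6x₁, 6x₂ and 6x₃, in rows 1, 5 and 9.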

mcabbott · Apr 26 '21 21:04

Thanks for your helpful response. Working from your example, this seems to work:

julia> Flux.gradient(x -> sum(Zygote.forwarddiff(z -> ForwardDiff.jacobian(model, z), x)), rand(5, 2))[1]
5×2 Matrix{Float64}:
 -0.00228453  -0.0116751
 -0.00126774  -0.00494717
  0.00138824   0.0168742
  0.0042229    0.0126724
 -0.00115428  -0.0195571

but modifying the loss function from the MWE above accordingly doesn't work:

function loss(x) 
    jac = Zygote.forwarddiff(z -> ForwardDiff.jacobian(model, z), x)
    reg = sum(jac)
    @show reg
    return reg
end

If I print out the gradients, it looks like they're missing:

julia> include("mwe.jl")
reg = 0.5245257543539097
gs[p] = nothing
gs[p] = nothing
reg = 0.5245257543539097
gs[p] = nothing
gs[p] = nothing

It's not clear to me why the first approach works but the second doesn't.

dfenn · May 16 '21 05:05

I think what you're seeing is what Zygote.forwarddiff warns about when it says "Note that the function f will drop gradients for any closed-over values". That is, you write ForwardDiff.jacobian(model, z) to compute the Jacobian of the model with respect to z, but what you later ask for, I think, is the effect of the parameters within model on that result. These are never tracked by ForwardDiff.jl, which only "perturbs" its explicit input. Your first example instead takes the gradient with respect to x, which is that explicit input, so it is tracked. But I agree it's disturbing that the parameters aren't tracked within a Zygote.gradient call.
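To illustrate the point about explicit inputs, here is a sketch that avoids closed-over parameters altogether. It uses a hypothetical hand-written stand-in for the Dense(n, n, sigmoid) layer (not Flux's) whose weights W and bias b are ordinary arguments, and it takes the outer derivative with respect to W with ForwardDiff too (forward-over-forward), so W is perturbed rather than dropped.

```julia
using ForwardDiff

# Hypothetical stand-in for the one-layer Dense(n, n, sigmoid) model in the MWE,
# with W and b passed explicitly instead of being captured inside `model`.
sigmoid(z) = 1 / (1 + exp(-z))
layer(W, b, x) = sigmoid.(W * x .+ b)

# Jacobian penalty for one input vector x: the sum of all entries of
# d layer(W, b, x) / dx, computed by the inner ForwardDiff call.
penalty(W, b, x) = sum(ForwardDiff.jacobian(z -> layer(W, b, z), x))

n = 4
W = randn(n, n)
b = randn(n)
x = rand(n)

# Outer derivative with respect to W. W is the explicit input of this outer
# call, so ForwardDiff tracks it (nested dual numbers with distinct tags).
gW = ForwardDiff.gradient(Wm -> penalty(Wm, b, x), W)
```

For this layer the Jacobian is Diagonal(σ'(Wx + b)) * W, so gW can be checked against the closed form gW[k, l] = s_k + r_k s_k (1 - 2σ(a_k)) x_l, with a = Wx + b, s = σ'(a) and r the row sums of W. Carrying this over to a Flux model would require making its parameters an explicit input in the same way (Flux.destructure returns a flat parameter vector and a rebuild function that might serve); whether that rebuild function accepts ForwardDiff dual numbers is not checked here.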

Here is a smaller example, using Zygote#master, which after #968 applies forwarddiff automatically:

julia> using Zygote, ForwardDiff, LinearAlgebra  # LinearAlgebra provides dot

julia> W = rand(2); X = rand(2);

julia> G = gradient(Params([W,X])) do
         sum(ForwardDiff.gradient(x -> dot(x,W)^3, X))
       end
Grads(...)

julia> G[X]
2-element Vector{Float64}:
 9.779807803787289
 5.941561578834241

julia> G[W] === nothing
true

The same is true if you change it to call forward-over-reverse instead, although perhaps there is some context that ought to be passed to the inner call here, to tell it that we care about W?

julia> G2 = gradient(Params([W,X])) do
         sum(Zygote.forwarddiff(X -> Zygote.gradient(x -> dot(x,W)^3, X)[1], X))
       end
Grads(...)

julia> G2[X]
2-element Vector{Float64}:
 9.779807803787289
 5.941561578834241

julia> G2[W] === nothing
true
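Working the smaller example out by hand shows that the W gradient should not be empty. Writing f(X, W) for the quantity inside gradient above (the sum of the gradient of (X·W)³ with respect to X), both partial derivatives are generally nonzero:

```latex
f(X, W) = \sum_i \frac{\partial}{\partial X_i}\,(X \cdot W)^3 = 3\,(X \cdot W)^2 \sum_i W_i

\frac{\partial f}{\partial X_j} = 6\,(X \cdot W)\, W_j \sum_i W_i

\frac{\partial f}{\partial W_j} = 6\,(X \cdot W)\, X_j \sum_i W_i + 3\,(X \cdot W)^2
```

Since W and X come from rand(2), every term here is positive, so G[W] === nothing (Zygote's marker for "no gradient") reflects the dropped dependence on the closed-over W rather than a true zero.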

mcabbott · May 16 '21 21:05