
use ForwardDiff.jacobian in place of Zygote.forward_jacobian

Open vpuri3 opened this pull request 1 year ago • 5 comments

Pursuant to https://github.com/FluxML/Zygote.jl/pull/1270

vpuri3 avatar Oct 26 '23 17:10 vpuri3

There are some Enzyme-related errors in the NNlib integration tests, but they seem unrelated to this PR.

vpuri3 avatar Oct 26 '23 21:10 vpuri3

@ToucheSir, LMK if I need to add more tests.

Here's a working MWE with Lux. This also resolves https://github.com/FluxML/Zygote.jl/issues/1348

With the change in this PR, this code works:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff

CUDA.allowscalar(false)

#==========================#
function testhessian(
    NN::Lux.AbstractExplicitLayer,
    data::Tuple;
    device = cpu_device(),
)
    p, st = Lux.setup(Random.default_rng(), NN)

    st = Lux.testmode(st)
    p = ComponentArray(p)

    xdata, ydata = data |> device
    p, st = (p, st)     |> device

    function loss(optx)
        ypred, _ = NN(xdata, optx, st)

        sum(abs2, ydata - ypred)
    end

    # Forward-over-reverse composition spelled out explicitly: a Zygote gradient
    # differentiated by ForwardDiff.jacobian, which is what this PR has
    # Zygote.hessian use in place of Zygote.forward_jacobian.
    g(p) = Zygote.gradient(loss, p)[1]
    H(p) = ForwardDiff.jacobian(g, p)

    Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))

data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()

H = testhessian(NN, data; device)

julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.236781  -0.075257    -1.20583    0.31846   -0.101217    -1.62179   -0.713834    0.503548  -1.14138     1.98508
 -0.075257   0.0239192    0.383253  -0.101217   0.0321702    0.515458   0.0296168  -0.780695   0.362769   -0.630924
 -1.20583    0.383253     6.1408    -1.62179    0.515458     8.2591     0.474545   -2.56436    5.19194   -10.1092
  0.318461  -0.101217    -1.62179    0.514738  -0.163601    -2.62135   -2.09317     0.677249  -1.53511     3.20854
 -0.101217   0.0321702    0.515458  -0.163601   0.0519977    0.833151   0.0398333  -2.18309    0.487909   -1.01978
 -1.62179    0.515458     8.2591    -2.62135    0.833151    13.3494     0.638242   -3.44895    5.84984   -16.3398
 -0.713834   0.0296168    0.474545  -2.09317    0.0398333    0.638242   0.0366717  -0.198167   0.449183   -0.781213
  0.503548  -0.780695    -2.56436    0.677249  -2.18309     -3.44895   -0.198167    1.07086   -2.4273      4.22154
 -1.14138    0.362769     5.19194   -1.53511    0.487909     5.84984    0.449183   -2.4273     5.50193    -9.56889
  1.98508   -0.630924   -10.1092     3.20854   -1.01978    -16.3398    -0.781213    4.22154   -9.56889    20.0

(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
  [052768ef] CUDA v5.0.0
  [b0b7db55] ComponentArrays v0.15.4
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.8
  [d0bbae9a] LuxCUDA v0.3.1
  [e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote`

vpuri3 avatar Oct 26 '23 21:10 vpuri3

Did this ever reach a conclusion? I need to take the Jacobian of a (Lux) model's output with respect to its inputs and then optimize that object using gradient-descent updates on the (Lux) model parameters. Something like the following:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote # https://github.com/vpuri3/Zygote.jl/tree/fwd
using ForwardDiff
using LinearAlgebra

CUDA.allowscalar(false)
## Setup
L = 5
bs = 3
m = Chain(Dense(L, L), relu, Dense(L, L))
ps, st = Lux.setup(Random.default_rng(), m)
dev = Lux.gpu_device()
ps = ComponentArray(ps) |> dev
x = randn(Float32, L, bs) |> dev
y = randn(Float32, L, bs) |> dev
## Forward
# Prediction: the diagonal of the Jacobian of the network output with respect
# to its input, computed with ForwardDiff and reshaped back to the input size.
function getpred(x, m, ps, st)
    function getpotential(x)
        return first(m(x, ps, st))
    end
    pred = reshape(diag(ForwardDiff.jacobian(getpotential, x)), size(x)...)
    return pred
end
pred = getpred(x, m, ps, st)
## Backward
mse(ypred, y) = sum(abs2, ypred .- y) / length(y) # simple MSE, defined here since none is imported above
function getgrads(x, y, m, ps, st)
    gs = Zygote.gradient(p -> mse(getpred(x, m, p, st), y), ps)
    return gs
end
gs = getgrads(x, y, m, ps, st) # returns (nothing,)

Or should I be looking towards JAX for this sort of thing? The use case is thermodynamics.

aksuhton avatar Feb 21 '24 16:02 aksuhton

That's a better question for the SciML/Lux help channels, not this issue tracker.

ToucheSir avatar Feb 21 '24 17:02 ToucheSir

This PR changes the implementation used internally for ForwardDiff-over-Zygote. It didn't get much attention, as it was a little unclear what it solves -- see the requests above for tests that fail before the change.
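
For reference, a minimal sketch (not taken from the PR; the names and toy loss are illustrative) of the forward-over-reverse composition in question, ForwardDiff.jacobian applied to a Zygote gradient:

using Zygote, ForwardDiff

loss(p) = sum(abs2, p)                   # an arbitrary scalar loss
grad(p) = Zygote.gradient(loss, p)[1]    # reverse mode (Zygote)
hess(p) = ForwardDiff.jacobian(grad, p)  # forward mode over the gradient
hess(rand(3))                            # 3×3 Hessian; here 2I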

Your example wants to do Zygote-over-ForwardDiff, which won't work, and would not be changed by this PR.

(Zygote has a rule for ForwardDiff.jacobian(f, x) which was probably a bad idea, and translates it to Fwd-over-Fwd. It should complain loudly when f closes over parameters, as it cannot work out the derivative with respect to f.)
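
To illustrate, a minimal sketch (illustrative values; the `(nothing,)` outcome is assumed, matching the result reported in the example above):

using Zygote, ForwardDiff

p = [1.0, 2.0]
x = [3.0, 4.0]

grads = Zygote.gradient(p) do p
    f(x) = p .* x                    # f closes over the parameters p
    sum(ForwardDiff.jacobian(f, x))  # the rule differentiates w.r.t. x, but drops f's closure
end
# grads == (nothing,): the derivative with respect to p is silently lost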

mcabbott avatar Feb 21 '24 17:02 mcabbott