Zygote.jl
use ForwardDiff.jacobian in place of Zygote.forward_jacobian
Pursuant to https://github.com/FluxML/Zygote.jl/pull/1270
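For context, the pattern involved here is forward-over-reverse: a forward-mode (ForwardDiff) Jacobian of the reverse-mode (Zygote) gradient. Here is a minimal sketch of that pattern using only the public APIs; it is not the PR's internal code, and the helper name is made up for illustration:

using Zygote, ForwardDiff

# forward-over-reverse: ForwardDiff.jacobian of the Zygote gradient gives a Hessian
hess_fwd_over_rev(f, x) = ForwardDiff.jacobian(z -> Zygote.gradient(f, z)[1], x)

rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
x = [0.5, 1.5]
hess_fwd_over_rev(rosenbrock, x) ≈ ForwardDiff.hessian(rosenbrock, x) # expected to agree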
There are some Enzyme-related errors in the NNlib integration tests, but they seem unrelated to this PR.
@ToucheSir, LMK if I need to add more tests.
Here's a working MWE with Lux. This also resolves https://github.com/FluxML/Zygote.jl/issues/1348. With the change in this PR, the following code works:
using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff
CUDA.allowscalar(false)
#==========================#
function testhessian(
    NN::Lux.AbstractExplicitLayer,
    data::Tuple;
    device = cpu_device(),
)
    p, st = Lux.setup(Random.default_rng(), NN)
    st = Lux.testmode(st)
    p = ComponentArray(p)

    xdata, ydata = data |> device
    p, st = (p, st) |> device

    function loss(optx)
        ypred, _ = NN(xdata, optx, st)
        sum(abs2, ydata - ypred)
    end

    g(p) = Zygote.gradient(loss, p)[1]
    H(p) = ForwardDiff.jacobian(g, p)

    Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))
data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()
H = testhessian(NN, data; device)
julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
0.236781 -0.075257 -1.20583 0.31846 -0.101217 -1.62179 -0.713834 0.503548 -1.14138 1.98508
-0.075257 0.0239192 0.383253 -0.101217 0.0321702 0.515458 0.0296168 -0.780695 0.362769 -0.630924
-1.20583 0.383253 6.1408 -1.62179 0.515458 8.2591 0.474545 -2.56436 5.19194 -10.1092
0.318461 -0.101217 -1.62179 0.514738 -0.163601 -2.62135 -2.09317 0.677249 -1.53511 3.20854
-0.101217 0.0321702 0.515458 -0.163601 0.0519977 0.833151 0.0398333 -2.18309 0.487909 -1.01978
-1.62179 0.515458 8.2591 -2.62135 0.833151 13.3494 0.638242 -3.44895 5.84984 -16.3398
-0.713834 0.0296168 0.474545 -2.09317 0.0398333 0.638242 0.0366717 -0.198167 0.449183 -0.781213
0.503548 -0.780695 -2.56436 0.677249 -2.18309 -3.44895 -0.198167 1.07086 -2.4273 4.22154
-1.14138 0.362769 5.19194 -1.53511 0.487909 5.84984 0.449183 -2.4273 5.50193 -9.56889
1.98508 -0.630924 -10.1092 3.20854 -1.01978 -16.3398 -0.781213 4.22154 -9.56889 20.0
(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
[052768ef] CUDA v5.0.0
[b0b7db55] ComponentArrays v0.15.4
[f6369f11] ForwardDiff v0.10.36
[b2108857] Lux v0.5.8
[d0bbae9a] LuxCUDA v0.3.1
[e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote`
Did this ever reach a conclusion? I need to be able to take the Jacobian of a (Lux) model's output with respect to its inputs, and then optimize that object with gradient-descent updates on the (Lux) model parameters. Something like the following:
using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote # https://github.com/vpuri3/Zygote.jl/tree/fwd
using ForwardDiff
using LinearAlgebra

CUDA.allowscalar(false)

# simple MSE loss; `mse` was not defined in the original snippet
mse(ypred, y) = sum(abs2, ypred .- y) / length(y)

## Setup
L = 5
bs = 3
m = Chain(Dense(L, L), relu, Dense(L, L))
ps, st = Lux.setup(Random.default_rng(), m)
dev = Lux.gpu_device()
ps = ComponentArray(ps) |> dev
x = randn(Float32, L, bs) |> dev
y = randn(Float32, L, bs) |> dev

## Forward
function getpred(x, m, ps, st)
    function getpotential(x)
        return first(m(x, ps, st))
    end
    pred = reshape(diag(ForwardDiff.jacobian(getpotential, x)), size(x)...)
    return pred
end

pred = getpred(x, m, ps, st)

## Backward
function getgrads(x, y, m, ps, st)
    gs = Zygote.gradient(p -> mse(getpred(x, m, p, st), y), ps)
    return gs
end

gs = getgrads(x, y, m, ps, st) # returns (nothing,)
Or should I be looking towards JAX for this sort of thing? The use case is thermodynamics.
That's a better question for the SciML/Lux help channels, not this issue tracker.
This PR changes the implementation used internally for FwdDiff-over-Zygote. It didn't get much attention, as it was a little unclear what it solves; see the requests above for tests which fail before the change.
Your example wants to do Zygote-over-ForwardDiff, which won't work, and would not be changed by this PR.
(Zygote has a rule for ForwardDiff.jacobian(f, x), which was probably a bad idea, and translates it to Fwd-over-Fwd. It should complain loudly when f closes over parameters, as it cannot work out the derivative with respect to f.)
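A hedged illustration of the closure problem described above (behaviour may differ across Zygote versions; the toy names are made up): when the function passed to ForwardDiff.jacobian captures the parameters being differentiated, the derivative with respect to those parameters is dropped and the outer gradient comes back as nothing, which is the same symptom as the Lux example above.

using Zygote, ForwardDiff

p  = rand(3)
x0 = rand(3)

# the closure x_ -> p_ .* x_ captures p_, so Zygote's rule for
# ForwardDiff.jacobian cannot propagate derivatives with respect to p_
loss(p_) = sum(ForwardDiff.jacobian(x_ -> p_ .* x_, x0))
Zygote.gradient(loss, p) # (nothing,), possibly with a warning, depending on the Zygote version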