Example of JVP / J'VP with Krylov.jl
cc: @michel2323 & @amontoison
@lcandiot and I were wondering how to write a proper JVP and J'VP with Enzyme, and we finally converged on a working version.
This might turn into a nice example one of these days.
@wsmoses, any ideas on how to avoid the calls to `zero(y)` and `zero(w)`/`copy(w)`?
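For reference, these are the two products for the `F` defined below, together with the identities that the ForwardDiff reference versions `J` and `Jᵀ` rely on (standard calculus, written out here for this example):

```latex
F(x) = \begin{pmatrix} x_1^4 - 3 \\ e^{x_2} - 2 \\ \log x_1 - x_2^2 \end{pmatrix},
\qquad
J(x) = \frac{\partial F}{\partial x}(x) = \begin{pmatrix} 4x_1^3 & 0 \\ 0 & e^{x_2} \\ 1/x_1 & -2x_2 \end{pmatrix}

J(x)\,v = \left.\frac{d}{dt} F(x + t v)\right|_{t=0},
\qquad
J(x)^\top w = \nabla_x \left( w^\top F(x) \right)
```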
```julia
using Krylov, Enzyme, LinearOperators, ForwardDiff, LinearAlgebra

xk = ones(2)
F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

"""
Calculate the Jacobian-vector product `J(u) * v` in-place by updating `y`.
"""
function JVP!(y, f::F, u, v) where F
    Enzyme.autodiff(Forward,
                    (temp, v) -> (temp .= f(v); nothing),
                    Const,
                    DuplicatedNoNeed(zero(y), y),  # primal output buffer; the tangent J(u) * v is written into y
                    DuplicatedNoNeed(u, v))        # evaluation point u, direction v
    return nothing
end

"""
Calculate the Jacobian-transpose-vector product `J(u)' * w` in-place by updating `y`.
"""
function JᵀVP!(y, f::F, u, w) where F
    y .= 0  # Enzyme accumulates into the shadow of u, so y must start at zero
    Enzyme.autodiff(Enzyme.Reverse,
                    (out, x) -> (out .= f(x); nothing),
                    Const,
                    DuplicatedNoNeed(zero(w), copy(w)),  # seed w; copy since otherwise Enzyme will zero it
                    DuplicatedNoNeed(u, y))              # J(u)' * w is accumulated into y
    return nothing
end

# ForwardDiff references: J v is the derivative of t -> F(xk + t v) at t = 0,
# and J' u is the gradient of x -> dot(F(x), u).
J(y, v) = ForwardDiff.derivative!(y, t -> F(xk + t * v), 0.0)
Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)

w = rand(3)
v = rand(2)

# Check J' w
y_fwd = zeros(2)
Jᵀ(y_fwd, w)
@show y_fwd
y_enz = zeros(2)
JᵀVP!(y_enz, F, xk, w)
@show y_enz
@assert y_enz ≈ y_fwd

# Check J v
y2_fwd = zeros(3)
J(y2_fwd, v)
@show y2_fwd
y2_enz = zeros(3)
JVP!(y2_enz, F, xk, v)
@show y2_enz
@assert y2_enz ≈ y2_fwd

# Matrix-free 3×2 operators (not symmetric, not Hermitian). For real Float64 data the
# transpose and adjoint products coincide, so the same function is passed for both.
# lsmr then solves the least-squares problem min ‖J(xk) Δx + F(xk)‖.
opJ_FWD = LinearOperator(Float64, 3, 2, false, false,
                         (y, v) -> J(y, v),
                         (y, w) -> Jᵀ(y, w),
                         (y, u) -> Jᵀ(y, u))
x_forward, _ = lsmr(opJ_FWD, -F(xk))

opJ = LinearOperator(Float64, 3, 2, false, false,
                     (y, v) -> JVP!(y, F, xk, v),
                     (y, w) -> JᵀVP!(y, F, xk, w),
                     (y, u) -> JᵀVP!(y, F, xk, u))
x_enzyme, _ = lsmr(opJ, -F(xk))

@assert x_enzyme ≈ x_forward
```
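Since this `F` is tiny, its Jacobian can also be written out by hand, which gives a check that does not depend on either AD package. The sketch below is plain Julia (only `LinearAlgebra`); `jac`, the fixed `v`/`w`, and the dense `\` solve are illustrative choices, not part of the original code.

```julia
using LinearAlgebra

# Same residual as in the first post.
F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

# Hand-derived 3×2 Jacobian of F, used only as an independent reference (hypothetical helper).
function jac(x)
    Jx = zeros(3, 2)
    Jx[1, 1] = 4 * x[1]^3
    Jx[2, 2] = exp(x[2])
    Jx[3, 1] = 1 / x[1]
    Jx[3, 2] = -2 * x[2]
    return Jx
end

xk = ones(2)
v = [0.3, 0.7]         # arbitrary test direction
w = [0.5, -1.0, 2.0]   # arbitrary test cotangent

Jk = jac(xk)
Jv = Jk * v     # expected content of y after JVP!(y, F, xk, v)
Jtw = Jk' * w   # expected content of y after JᵀVP!(y, F, xk, w)

# Sanity check of jac itself against a forward difference.
h = 1e-7
fd = (F(xk + h * v) - F(xk)) / h
@assert isapprox(fd, Jv; rtol = 1e-5)

# Dense least-squares solution of min ‖Jk * Δx + F(xk)‖₂, the problem that
# lsmr(opJ, -F(xk)) solves matrix-free.
Δx = Jk \ (-F(xk))
```

With `jac` in scope, the checks in the first post could additionally assert `y_enz ≈ jac(xk)' * w` and `y2_enz ≈ jac(xk) * v`, and `x_enzyme` should agree with `Δx` up to LSMR's stopping tolerances.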
Depending on the array type, I think using `copyto!` would let you pass an uninitialized buffer in place of `zero(y)`:
```julia
function JVP!(y, f::F, u, v) where F
    Enzyme.autodiff(Forward,
                    (temp, v) -> (temp .= f(v); nothing),
                    Const,
                    DuplicatedNoNeed(similar(y), y),  # "any undef thing": its initial values are never read
                    DuplicatedNoNeed(u, v))
    return nothing
end
```
I think the same applies to `zero(w)`; the `copy(w)`, however, is harder to avoid.
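The `copy(w)` problem comes from how reverse mode treats a mutated output. The plain-Julia model below mimics the convention the `JᵀVP!` comments describe: the input adjoint is accumulated into its shadow (hence `y .= 0`), and the shadow of the overwritten output is zeroed after use (hence `copy(w)`). It is a sketch with a hand-derived Jacobian and a hypothetical `reverse_model!`, not how Enzyme is implemented internally.

```julia
using LinearAlgebra

# Hand-derived Jacobian of F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2].
function jac(x)
    Jx = zeros(3, 2)
    Jx[1, 1] = 4 * x[1]^3
    Jx[2, 2] = exp(x[2])
    Jx[3, 1] = 1 / x[1]
    Jx[3, 2] = -2 * x[2]
    return Jx
end

# Hypothetical model (not Enzyme code) of the reverse pass of
# `(out, x) -> (out .= F(x); nothing)` with seed shadow `dout` and input shadow `dx`.
function reverse_model!(dout, dx, x)
    dx .+= jac(x)' * dout   # the input adjoint is accumulated, not overwritten
    dout .= 0               # `out` was overwritten, so its seed is consumed and zeroed
    return nothing
end

x = ones(2)
w = [0.5, -1.0, 2.0]

# Stale result buffer: the new product is added to whatever y already holds.
y = [10.0, 10.0]
reverse_model!(copy(w), y, x)
y_stale = copy(y)

# What JᵀVP! does: zero y first and pass a copy of the seed.
y .= 0
reverse_model!(copy(w), y, x)
y_clean = copy(y)

# Passing the caller's vector directly wipes it.
w_passed = copy(w)
reverse_model!(w_passed, zeros(2), x)
```

Under this model the seed buffer handed to the reverse pass is always clobbered, so `copy(w)` can at best be replaced by copying `w` into a preallocated scratch vector (e.g. with `copyto!`), which avoids the allocation but not the copy itself.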
For future reference:
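```julia
using LinearAlgebra

# https://www.aanda.org/articles/aa/full_html/2016/02/aa27339-15/aa27339-15.html
# Matrix-free JVP by forward finite differences.
function JVP_Finite_Diff(F, u, v)
    λ = 10e-6  # note: 10e-6 == 1e-5
    δ = λ * (λ + norm(u, Inf) / norm(v, Inf))
    return (F(u + δ .* v) - F(u)) ./ δ
end
```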
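A quick way to gauge this finite-difference JVP is to compare it with the exact product `J(u) * v` for the `F` from the first post. The sketch reuses the formula as written (including `λ = 10e-6`); `exact_jvp` is a hypothetical helper built from the hand-derived Jacobian, and the test vector `v` is an arbitrary choice.

```julia
using LinearAlgebra

F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

# Finite-difference JVP as posted above (step-size heuristic unchanged).
function JVP_Finite_Diff(F, u, v)
    λ = 10e-6  # == 1e-5
    δ = λ * (λ + norm(u, Inf) / norm(v, Inf))
    return (F(u + δ .* v) - F(u)) ./ δ
end

# Exact J(u) * v for this particular F (hypothetical helper).
exact_jvp(u, v) = [4 * u[1]^3 * v[1],
                   exp(u[2]) * v[2],
                   v[1] / u[1] - 2 * u[2] * v[2]]

u = ones(2)
v = [0.3, 0.7]
fd = JVP_Finite_Diff(F, u, v)
ex = exact_jvp(u, v)
relerr = norm(fd - ex) / norm(ex)
```

Forward differences are only first-order accurate in δ, so `relerr` here is of the order of δ (about `1e-5`) rather than machine precision, which is one reason to prefer the ForwardDiff or Enzyme products.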
By "any undef thing", do you mean that a `Vector{Float64}(undef, 0)` would work?
I think so; in C we get away with passing a literal `nullptr` in these kinds of cases.
Do you plan to add `jvp` and `jtvp` to the Enzyme.jl API?
I'd like to know whether I should wait before adding an example to the Krylov.jl documentation.