OrdinaryDiffEq.jl
Using Enzyme for Jacobians, Adding an option in OrdinaryDiffEq
Hey! Now that I actually understand what a Jacobian is 😭 (it's coming full circle, so nice to finally be here), how do we move forward from here? I notice you guys are currently using ForwardDiff for Jacobians in a couple of places here.
I think we need to add an Enzyme option to `calc_J` and `calc_J!`, where the Jacobians are calculated, in https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/src/derivative_utils.jl
Calculating a Jacobian with Enzyme.jl is simply:
Forward mode:
```julia
using Enzyme

# Maps R^2 -> R^3, so the Jacobian is a 3×2 matrix
function test(v)
    [v[2], v[1]*v[1], v[1]*v[1]*v[1]]
end
jac = Enzyme.jacobian(Forward, test, [2.0, 3.0])
```
Reverse mode:
```julia
using Enzyme

function test(v)
    [v[2], v[1]*v[1], v[1]*v[1]*v[1]]
end
# Val(3) gives the number of outputs of `test`
jac = Enzyme.jacobian(Reverse, test, [2.0, 3.0], Val(3))
```
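Both modes compute the same 3×2 Jacobian mathematically. As a sanity check that needs no AD package, the Jacobian of `test` can be derived by hand and compared with central finite differences. `analytic_jac` and `fd_jacobian` are hypothetical helpers for illustration only, not part of Enzyme or OrdinaryDiffEq.

```julia
# Same test function as above: R^2 -> R^3, so its Jacobian is 3×2.
test(v) = [v[2], v[1]*v[1], v[1]*v[1]*v[1]]

# Hand-derived Jacobian: rows are outputs, columns are inputs.
analytic_jac(v) = [0.0      1.0;
                   2v[1]    0.0;
                   3v[1]^2  0.0]

# Central finite differences, one column per input component.
function fd_jacobian(f, x; h = 1e-6)
    y = f(x)
    J = zeros(length(y), length(x))
    for j in eachindex(x)
        xp = copy(x); xp[j] += h
        xm = copy(x); xm[j] -= h
        J[:, j] = (f(xp) - f(xm)) / (2h)
    end
    return J
end

x = [2.0, 3.0]
J_exact = analytic_jac(x)   # [0.0 1.0; 4.0 0.0; 12.0 0.0]
J_fd = fd_jacobian(test, x)
@assert isapprox(J_fd, J_exact; atol = 1e-5)
```

Whatever the Enzyme calls return at `[2.0, 3.0]` should match `J_exact` up to floating-point rounding.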
This is very crude and likely wrong: is the change going to look like the following, or will it use `jacobian(uf, x, integrator)`, which, if I understand correctly, seems to act as a common interface?
```julia
function calc_J(integrator, cache, next_step::Bool = false; use_enzyme = true)
    @unpack dt, t, uprev, f, p, alg = integrator
    if next_step
        t = t + dt
        uprev = integrator.u
    end
    if alg isa DAEAlgorithm
        if DiffEqBase.has_jac(f)
            # `duprev` is not unpacked above, so it still has to come from somewhere
            J = f.jac(duprev, uprev, p, t)
        elseif use_enzyme
            @info "Enzyme.jacobian"
            # `f` is called as f(du, u, p, t) here, not f(u)
            J = Enzyme.jacobian(Forward, f, uprev)
        else
            @unpack uf = cache
            x = zero(uprev)
            J = jacobian(uf, x, integrator)
        end
    else
        if DiffEqBase.has_jac(f)
            J = f.jac(uprev, p, t)
        elseif use_enzyme
            @info "Enzyme.jacobian"
            # `f` is called as f(u, p, t) here, not f(u)
            J = Enzyme.jacobian(Forward, f, uprev)
        else
            @unpack uf = cache
            uf.f = nlsolve_f(f, alg)
            uf.p = p
            uf.t = t
            J = jacobian(uf, uprev, integrator)
        end
        integrator.stats.njacs += 1
        if alg isa CompositeAlgorithm
            integrator.eigen_est = constvalue(opnorm(J, Inf))
        end
    end
    J
end
```
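The Enzyme branches above pass `f` straight to `Enzyme.jacobian`, but `f` takes `(u, p, t)` (or `(du, u, p, t)` when it works in place), so whichever backend is used needs a function of the state alone, with `p` and `t` held fixed. Below is a minimal sketch of that idea for the in-place setting that `calc_J!` targets. It assumes a hypothetical right-hand side `rhs!` and a hypothetical helper `fd_jacobian!`, with central finite differences standing in for Enzyme so it runs on Base Julia; it fills a preallocated `J` rather than returning a new matrix.

```julia
# Hypothetical in-place right-hand side with the f(du, u, p, t) signature.
function rhs!(du, u, p, t)
    du[1] = p[1] * u[1] - u[1] * u[2]
    du[2] = u[1] * u[2] - p[2] * u[2]
    return nothing
end

# Fill a preallocated J with central differences of an in-place g!(out, x).
function fd_jacobian!(J, g!, x; h = 1e-6)
    fp = similar(x, size(J, 1))
    fm = similar(x, size(J, 1))
    xp = copy(x)
    for j in eachindex(x)
        xp[j] = x[j] + h; g!(fp, xp)
        xp[j] = x[j] - h; g!(fm, xp)
        xp[j] = x[j]                  # restore before the next column
        J[:, j] .= (fp .- fm) ./ (2h)
    end
    return J
end

p, t = [1.5, 3.0], 0.0
uprev = [1.0, 2.0]
J = zeros(2, 2)
# Freeze p and t in a closure so only the state is perturbed.
fd_jacobian!(J, (du, u) -> rhs!(du, u, p, t), uprev)
# Exact Jacobian here: [p[1]-u[2] -u[1]; u[2] u[1]-p[2]] = [-0.5 -1.0; 2.0 -2.0]
```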
Also, I am assuming `uprev` is the point, like the `[2.0, 3.0]` you mentioned, and hence passing it to `Enzyme.jacobian`. I think it might be `u` rather than `uprev`, though. The docstrings are a bit sparse, but this helps: https://docs.sciml.ai/DiffEqDocs/stable/basics/integrator/
You'd probably just want to use the same wrapped `f`, i.e. `uf`, just like the ForwardDiff call, and add a new dispatch to `jacobian`.
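One way to picture that suggestion, as a sketch only: `UWrapper`, `JacBackend`, `CentralFD`, `ForwardFD` and `my_jacobian` are hypothetical names, not OrdinaryDiffEq's actual wrapper or AD choice types, and finite differences stand in for Enzyme so the example runs on Base Julia alone. The wrapper stores `f`, `p` and `t` so it can be called with `u` alone (mirroring the `uf.p = p` / `uf.t = t` updates in `calc_J`), and each Jacobian backend is just another method of the same function.

```julia
# Hypothetical stand-in for a wrapper like `uf`: stores f, p and t so the
# right-hand side can be called with the state u alone.
mutable struct UWrapper{F, P, T}
    f::F
    p::P
    t::T
end
(w::UWrapper)(u) = w.f(u, w.p, w.t)

# Hypothetical backend types (not OrdinaryDiffEq's own AD choice types).
abstract type JacBackend end
struct CentralFD <: JacBackend
    h::Float64
end
struct ForwardFD <: JacBackend
    h::Float64
end

# One method per backend: supporting a new option means adding one method.
function my_jacobian(w, x, b::CentralFD)
    cols = map(eachindex(x)) do j
        e = zeros(length(x)); e[j] = b.h
        (w(x .+ e) .- w(x .- e)) ./ (2 * b.h)
    end
    return reduce(hcat, cols)   # column j holds the derivatives w.r.t. x[j]
end

function my_jacobian(w, x, b::ForwardFD)
    y = w(x)
    cols = map(eachindex(x)) do j
        e = zeros(length(x)); e[j] = b.h
        (w(x .+ e) .- y) ./ b.h
    end
    return reduce(hcat, cols)
end

# Usage with a hypothetical Lotka-Volterra-style right-hand side.
rhs(u, p, t) = [p[1] * u[1] - u[1] * u[2], u[1] * u[2] - p[2] * u[2]]
uf = UWrapper(rhs, [1.5, 3.0], 0.0)
uf.t = 0.1                     # update fields in place, like `uf.t = t` in calc_J
u = [1.0, 2.0]
J_central = my_jacobian(uf, u, CentralFD(1e-6))
J_forward = my_jacobian(uf, u, ForwardFD(1e-6))
```

Under this design, an Enzyme option would be one more backend type plus one more `my_jacobian` method that differentiates the wrapper, leaving the call site in `calc_J` unchanged.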