Wrong direct forward differentiation of an ODE
I want to forward-differentiate an ODE, for example using this code:
using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using Enzyme
const T = ComplexF64
const A1 = sparse(T[0.0 1.0; 0.0 0.0])
const A2 = sparse(T[0.0 0.0; 1.0 0.0])
function dudt!(du, u, p, t)
    mul!(du, A1, u, -p[1], zero(T))
    mul!(du, A2, u, p[2], one(T))
    return nothing
end
function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end
# %%
p = [1.0, 2.0]
my_fun2(p) - my_fun2(p .+ 1)
dp = Enzyme.make_zero(p)
dp[1] = 1
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), my_fun2, Duplicated(p, dp)) # (0.0,)
dp
However, dp is not updated and the gradient returns 0, which is wrong.
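For comparison, a rough central-difference estimate of ∂my_fun2/∂p[1] (a minimal sketch using only the definitions above, not part of the original report) gives a reference value that should clearly be nonzero:
h = 1e-6
fd = (my_fun2([1.0 + h, 2.0]) - my_fun2([1.0 - h, 2.0])) / (2h)  # finite-difference estimate of ∂/∂p[1]
# fd should be visibly nonzero, so the (0.0,) returned by forward-mode Enzyme above is suspect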
This is a bit large to dive into. @ChrisRackauckas, can you help reduce it?
@albertomercurio ideally you would first check that the behavior of your ODE function itself is correct.
https://docs.sciml.ai/SciMLSensitivity/stable/faq/nothing/SciMLSensitivity/stable/faq/#How-do-I-isolate-potential-gradient-issues-and-improve-performance?
It has some tips on how to extract that function.
However dp is not updated
Why should dp be updated when using Forward mode?
Hi @vchuravy. SciMLSensitivity.jl should not have anything to do with this error. I have already opened an issue there (https://github.com/SciML/SciMLSensitivity.jl/issues/1226), but forward differentiation does not call SciMLSensitivity.
Reverse differentiation works instead.
Love that SciML links are broken... Anyways please read the first entry here https://docs.sciml.ai/SciMLSensitivity/stable/faq/
Remember that in forward mode, derivative calculation happens in the direction of computation. Thus dp will never be updated, since your code is not updating p.
So you are calculating the forward sensitivity.
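For illustration, a minimal sketch of what forward mode does return (not from the thread, just Enzyme on a toy scalar function): the shadow dp is only the input seed, and the directional derivative comes back as the return value.
g(q) = q[1]^2 + 3 * q[2]
q  = [1.0, 2.0]
dq = [1.0, 0.0]                          # seed the direction ∂/∂q[1]
Enzyme.autodiff(Enzyme.Forward, g, Duplicated(q, dq))  # (2.0,): the derivative is the return value
dq                                       # unchanged; it is read as the tangent seed, never written to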
Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for: store i64 %36, i64 addrspace(11)* %47, align 8, !dbg !323, !tbaa !330, !alias.scope !255, !noalias !258 const val: %36 = ptrtoint {}* %35 to i64, !dbg !324
Type tree: {[-1]:Pointer}
llvalue= %36 = ptrtoint {}* %35 to i64, !dbg !324
Stacktrace:
[1] FunctionWrapper
@ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:107
[2] #1
@ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:28
[3] map
@ ./tuple.jl:382
[4] FunctionWrappersWrapper (repeats 2 times)
@ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:27
[5] wrapfun_iip
@ ~/.julia/packages/DiffEqBase/4uSBa/src/norecompile.jl:26
[6] promote_f
@ ~/.julia/packages/DiffEqBase/4uSBa/src/solve.jl:858
[7] #get_concrete_problem#51
@ ~/.julia/packages/DiffEqBase/4uSBa/src/solve.jl:767
SciMLSensitivity isn't even loaded in this example. This is Enzyme applied directly to the solver, i.e. direct forward-mode differentiation, and it bypasses the SciMLSensitivity adjoint pieces. The only rules that would be involved are:
https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl
@ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:107
Likely what could help here is:
function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end
to turn off the FunctionWrapper-ing.
I just set up some tests on this in OrdinaryDiffEq: https://github.com/SciML/OrdinaryDiffEq.jl/pull/2872
using Enzyme, OrdinaryDiffEqTsit5, StaticArrays, DiffEqBase, ForwardDiff, Test
function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end
const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
function f_dt(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    y .= sol[1,:]
    return nothing
end;
function f_dt(u0)
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    sol[1,:]
end;
u0 = [1.0; 0.0; 0.0]
fdj = ForwardDiff.jacobian(f_dt, u0)
ezj = stack(map(1:3) do i
    d_u0 = zeros(3)
    dy = zeros(13)
    y = zeros(13)
    d_u0[i] = 1.0
    Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0));
    dy
end)
@test ezj ≈ fdj
function f_dt2(u0)
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    sum(sol[1,:])
end
fdg = ForwardDiff.gradient(f_dt2, u0)
d_u0 = zeros(3)
Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0));
@test d_u0 ≈ fdg
This all works fine.
using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using Enzyme
const T = ComplexF64
const A1 = sparse(T[0.0 1.0; 0.0 0.0])
const A2 = sparse(T[0.0 0.0; 1.0 0.0])
function dudt!(du, u, p, t)
    mul!(du, A1, u, -p[1], zero(T))
    mul!(du, A2, u, p[2], one(T))
    return nothing
end
function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end
p = [1.0, 2.0]
dp = Enzyme.make_zero(p)
dp[1] = 1
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), my_fun2, Duplicated(p, dp))[1] # 3.1442798803e-314
The MWE without function wrappers is still a zero. And this is after double-checking that the other MWEs work. So there definitely does seem to be something wrong here... maybe something with differentiation of the sparse matrix operations?
SciMLSensitivity isn't even loaded in this example.
My point was that SciMLSensitivity has a nice FAQ on this topic, and that one should first check that the derivatives of the function alone are correct.
I see. Yeah, this needed to be isolated a bit; I thought it could just be another regression of direct differentiation of the solvers, but from the above it's something that can maybe be isolated to just f.
I have followed that guide, setting
u0 = T[3.0, 4.0]
du = similar(u0)
prob = ODEProblem{true}(dudt!, u0, (0.0, 1.0), p)
f! = prob.f
d_f! = Enzyme.make_zero(f!)
d_u0 = Enzyme.make_zero(u0)
d_du = Enzyme.make_zero(du)
d_p = Enzyme.make_zero(p)
d_t = Enzyme.make_zero(0.0)
d_u0[1] = 1.0
d_du[1] = 1.0
d_p[1] = 1.0
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), Duplicated(f!, d_f!), Enzyme.Const, Duplicated(u0, d_u0), Duplicated(du, d_du), Duplicated(p, d_p), Enzyme.Const(0.1))
But I still get zero everywhere. The same seems to happen for dense matrices, with the additional issue that this check fails because runtime activity has not been set for forward-mode differentiation of BLAS.
@albertomercurio I'm confused about what you're expecting. Forward mode propagates derivatives from input to output. In your original code [lmk if I should look at something else], p is not modified, so as a result dp is correctly not modified.
The derivative should propagate from p to du. IIRC, this is giving a zero on du even though the parameter has a derivative seed of 1 and it's tied in via mul!(du, A1, u, -p[1], zero(T)); mul!(du, A2, u, p[2], one(T)). I think it's not propagating the previous du derivative through the second 5-arg mul!, i.e. it's possibly calculating the derivative of A*B*α only instead of the full A*B*α + C*β.
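A minimal sketch of that hypothesis (not from the thread, reusing the definitions from the first post): push only a p[1] tangent through dudt! directly. If the C*β accumulation in the second 5-arg mul! is differentiated correctly, the tangent of du should come out as [-u[2], 0]; if the previous du tangent is dropped, it comes out as all zeros.
u    = T[3.0, 4.0];           du   = zeros(T, 2)
d_u  = Enzyme.make_zero(u);   d_du = Enzyme.make_zero(du)
p    = [1.0, 2.0];            d_p  = [1.0, 0.0]    # seed only the ∂/∂p[1] direction
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), dudt!, Enzyme.Const,
                Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Enzyme.Const(0.0))
d_du  # expected ≈ [-4.0, 0.0], i.e. [-u[2], 0]; all zeros would point at the dropped C*β contribution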