
Wrong direct forward differentiation of an ODE

albertomercurio opened this issue Aug 20 '25 • 13 comments

I want to forward-differentiate an ODE, for example with this code:

using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using Enzyme
const T = ComplexF64

const A1 = sparse(T[0.0 1.0; 0.0 0.0])
const A2 = sparse(T[0.0 0.0; 1.0 0.0])

function dudt!(du, u, p, t)
    mul!(du, A1, u, -p[1], zero(T))  # du .= -p[1] * (A1 * u)   (β = 0 overwrites du)
    mul!(du, A2, u, p[2], one(T))    # du .+= p[2] * (A2 * u)   (5-arg mul!: C = α*A*B + β*C)
    return nothing
end

function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end

# %%

p = [1.0, 2.0]

my_fun2(p) - my_fun2(p .+ 1) # sanity check: the output really depends on p

dp = Enzyme.make_zero(p)
dp[1] = 1
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), my_fun2, Duplicated(p, dp)) # (0.0,)
dp

However, dp is not updated and the returned derivative is 0, which is wrong.
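
For reference, here is a quick central-difference check of the same directional derivative (a sketch using the definitions above; h is an arbitrary step size):

# Central-difference estimate of the directional derivative of my_fun2
# along dp = [1, 0], for comparison with the 0.0 returned by Enzyme.
h = 1e-6
fd_estimate = (my_fun2(p .+ h .* dp) - my_fun2(p .- h .* dp)) / (2h)
# A nonzero fd_estimate confirms the Enzyme result above is wrong.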

albertomercurio commented Aug 20 '25

This is a bit large to dive into. @ChrisRackauckas, can you help reduce it?

wsmoses commented Sep 11 '25

@albertomercurio, ideally you would first check that the derivative behavior of your ODE function alone is correct.

https://docs.sciml.ai/SciMLSensitivity/stable/faq/#How-do-I-isolate-potential-gradient-issues-and-improve-performance?

has some tips on how to extract that function.

"However dp is not updated"

Why should dp be updated when using Forward mode?

vchuravy commented Sep 11 '25

Hi @vchuravy. SciMLSensitivity.jl should not have anything to do with this error. I have already opened an issue there (https://github.com/SciML/SciMLSensitivity.jl/issues/1226), but forward differentiation does not call SciMLSensitivity.

Reverse differentiation works instead.

albertomercurio commented Sep 11 '25

Love that the SciML links are broken... Anyway, please read the first entry here: https://docs.sciml.ai/SciMLSensitivity/stable/faq/

Remember that in forward mode, derivative calculations happen in the direction of computation. Thus dp will never be updated, since your code does not update p.

So you are calculating the forward sensitivity.
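
To illustrate the direction of propagation (a minimal sketch, independent of the ODE code; g is a made-up function):

# Forward mode reads the input shadow as a seed and returns the output
# tangent; it never writes back into dx.
g(x) = x[1]^2
x, dx = [3.0], [1.0]
Enzyme.autodiff(Enzyme.Forward, g, Duplicated(x, dx))  # (6.0,)
dx  # still [1.0]: unchanged, as expected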

vchuravy commented Sep 11 '25

Constant memory is stored (or returned) to a differentiable variable.

As a result, Enzyme cannot provably ensure correctness and throws this error.

This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).

If Enzyme should be able to prove this use non-differentable, open an issue!

To work around this issue, either:

 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or

 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.

Mismatched activity for:   store i64 %36, i64 addrspace(11)* %47, align 8, !dbg !323, !tbaa !330, !alias.scope !255, !noalias !258 const val:   %36 = ptrtoint {}* %35 to i64, !dbg !324

Type tree: {[-1]:Pointer}

 llvalue=  %36 = ptrtoint {}* %35 to i64, !dbg !324


Stacktrace:

 [1] FunctionWrapper
   @ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:107
 [2] #1
   @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:28
 [3] map
   @ ./tuple.jl:382
 [4] FunctionWrappersWrapper (repeats 2 times)
   @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:27
 [5] wrapfun_iip
   @ ~/.julia/packages/DiffEqBase/4uSBa/src/norecompile.jl:26
 [6] promote_f
   @ ~/.julia/packages/DiffEqBase/4uSBa/src/solve.jl:858
 [7] #get_concrete_problem#51
   @ ~/.julia/packages/DiffEqBase/4uSBa/src/solve.jl:767

vchuravy commented Sep 11 '25

SciMLSensitivity isn't even loaded in this example. This is Enzyme applied directly to the solver, i.e. direct forward-mode differentiation, and it bypasses the SciMLSensitivity adjoint pieces. The only rules that would be involved are:

https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl

@ ~/.julia/packages/FunctionWrappers/Q5cBx/src/FunctionWrappers.jl:107

Likely what could help here is:

function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end

to turn off the FunctionWrapper-ing.
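
Note: FullSpecialize lives in SciMLBase, so depending on the environment the snippet above may additionally need an explicit using SciMLBase for that name to resolve (an assumption about the setup, not something stated above).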

ChrisRackauckas commented Sep 11 '25

I just set up some tests for this in OrdinaryDiffEq: https://github.com/SciML/OrdinaryDiffEq.jl/pull/2872

using Enzyme, OrdinaryDiffEqTsit5, StaticArrays, DiffEqBase, ForwardDiff, Test

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]

function f_dt(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    y .= sol[1,:]
    return nothing
end;

function f_dt(u0)
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    sol[1,:]
end;

u0 = [1.0; 0.0; 0.0]
fdj = ForwardDiff.jacobian(f_dt, u0)

ezj = stack(map(1:3) do i
    d_u0 = zeros(3)
    dy = zeros(13)
    y  = zeros(13)
    d_u0[i] = 1.0  # seed the i-th basis vector to build the Jacobian column by column
    Enzyme.autodiff(Forward, f_dt,  Duplicated(y, dy), Duplicated(u0, d_u0));
    dy
end)

@test ezj ≈ fdj

function f_dt2(u0)
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
    sum(sol[1,:])
end

fdg = ForwardDiff.gradient(f_dt2, u0)
d_u0 = zeros(3)
Enzyme.autodiff(Reverse, f_dt2,  Active, Duplicated(u0, d_u0));

@test d_u0 ≈ fdg

This all works fine.

ChrisRackauckas commented Sep 11 '25

using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using Enzyme
const T = ComplexF64

const A1 = sparse(T[0.0 1.0; 0.0 0.0])
const A2 = sparse(T[0.0 0.0; 1.0 0.0])

function dudt!(du, u, p, t)
    mul!(du, A1, u, -p[1], zero(T))
    mul!(du, A2, u, p[2], one(T))
    return nothing
end

function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(dudt!, x, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false)
    return real(sol.u[end][end])
end

p = [1.0, 2.0]
dp = Enzyme.make_zero(p)
dp[1] = 1
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), my_fun2, Duplicated(p, dp))[1] # 3.1442798803e-314

The MWE without function wrappers still gives a zero. This is after double-checking that the other MWEs work. So there definitely does seem to be something wrong here... maybe something with the differentiation of sparse matrix operations?
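
A sketch that isolates that suspicion by differentiating dudt! directly under Forward mode, outside the solver (the expected tangent is computed by hand; values here are illustrative):

# Seed the tangent along p[1] and check what comes out on du.
u  = T[3.0, 4.0]
du = zeros(T, 2)
d_u  = Enzyme.make_zero(u)
d_du = Enzyme.make_zero(du)
d_p  = [1.0, 0.0]
Enzyme.autodiff(Enzyme.Forward, dudt!,
    Duplicated(du, d_du), Duplicated(u, d_u),
    Duplicated([1.0, 2.0], d_p), Enzyme.Const(0.0))
# By hand: d_du should equal -d_p[1] * (A1 * u) = T[-4.0, 0.0]
d_du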

ChrisRackauckas commented Sep 11 '25

"SciMLSensitivity isn't even loaded in this example."

My point was that SciMLSensitivity has a nice FAQ on this topic, and that one should first check that the derivatives of the function alone are correct.

vchuravy commented Sep 11 '25

I see. Yeah, this needed some isolation. I thought it could just be another regression in the direct differentiation of the solvers, but from the above it seems it can be isolated to just f.

ChrisRackauckas commented Sep 11 '25

I have followed that guide, setting

u0 = T[3.0, 4.0]
du = similar(u0)

prob = ODEProblem{true}(dudt!, u0, (0.0, 1.0), p)

f! = prob.f

d_f! = Enzyme.make_zero(f!)
d_u0 = Enzyme.make_zero(u0)
d_du = Enzyme.make_zero(du)
d_p = Enzyme.make_zero(p)
d_t = Enzyme.make_zero(0.0)

d_u0[1] = 1.0
d_du[1] = 1.0
d_p[1] = 1.0

# Note: the ODEFunction signature is f!(du, u, p, t), so du comes first.
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Forward), Duplicated(f!, d_f!), Enzyme.Const, Duplicated(du, d_du), Duplicated(u0, d_u0), Duplicated(p, d_p), Enzyme.Const(0.1))

But I still get zero everywhere. The same seems to happen for dense matrices, with the additional issue that this check fails, as runtime activity has not been implemented for forward-mode differentiation of BLAS.

albertomercurio commented Oct 12 '25

@albertomercurio I'm confused about what you're expecting. Forward mode propagates derivatives from input to output. In your original code [let me know if I should look at something else], p is not modified; as a result, dp is correctly not modified.

wsmoses commented Nov 08 '25

The derivative should propagate from p to du. IIRC, this is giving a zero on du even though the parameter has derivative 1 and it's tied via mul!(du, A1, u, -p[1], zero(T)); mul!(du, A2, u, p[2], one(T)). I think it's not propagating the previous du derivative through the second 5-arg mul!, i.e. for C = α*A*B + β*C it's possibly calculating the derivative of α*A*B only, instead of α*A*B + β*C.
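
That hypothesis can be tested in isolation (a hedged sketch; Asp and accum! are made-up names, and if the β*C tangent is dropped, dC comes out [4.0, 0.0] instead of [5.0, 0.0]):

using LinearAlgebra, SparseArrays, Enzyme

const Asp = sparse(ComplexF64[0.0 1.0; 0.0 0.0])

# 5-arg mul! computes C = α*Asp*B + β*C; the forward rule should carry
# the incoming tangent of C through the β*C term.
function accum!(C, B, α)
    mul!(C, Asp, B, α[1], one(ComplexF64))
    return nothing
end

C,  dC = ComplexF64[1.0, 1.0], ComplexF64[1.0, 0.0]  # nonzero incoming tangent on C
B,  dB = ComplexF64[3.0, 4.0], zeros(ComplexF64, 2)
α,  dα = [2.0], [1.0]
Enzyme.autodiff(Enzyme.Forward, accum!,
    Duplicated(C, dC), Duplicated(B, dB), Duplicated(α, dα))
# By hand: dC = dα*(Asp*B) + β*dC_in = [4.0, 0.0] + [1.0, 0.0] = [5.0, 0.0]
dC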

ChrisRackauckas commented Nov 09 '25