DiffEqBase.jl

Enzyme Rule Fails for DDE

Open • m-bossart opened this issue • 1 comment

The custom Enzyme rule for solve_up fails when running sensitivity analysis on a delay differential equation problem (DDEProblem).

I expected the same rule that works for ODEProblem to also work for DDEProblem.

Minimal Reproducible Example 👇

using OrdinaryDiffEq
using DelayDiffEq
using SciMLSensitivity
using Enzyme
using Zygote
using Test

## Zygote matches Enzyme for ODEProblem

# In-place RHS: du/dt = p * u
f(du, u, p, t) = du .= u .* p
u0p = [2.0, 3.0]
# Loss: u0p packs the initial condition and the parameter, u0p = [u0; p]
function f(u0p)
    prob = ODEProblem{true}(f, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-9, reltol = 1e-9, saveat = 0.1))
end
du0p_zygote = Zygote.gradient(f, u0p)[1]
du0p = zeros(2)
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))
@test du0p_zygote ≈ du0p

## Enzyme fails for DDEProblem

# In-place DDE RHS with a constant delay of 0.01
function f_delay(du, u, h, p, t)
    du .= u .* p .* h(p, t - 0.01)[1]
end
h(p, t) = ones(eltype(p), 2)   # constant history function
u0p = [2.0, 3.0]
# Redefines the loss f(u0p) for the DDE case; constant_lags matches the 0.01 delay used in f_delay
function f(u0p)
    prob = DDEProblem{true}(f_delay, u0p[1:1], h, (0.0, 0.2), u0p[2:2], constant_lags = [0.01])
    sum(solve(prob, MethodOfSteps(Rodas4()), abstol = 1e-9, reltol = 1e-9, saveat = 0.1))
end
du0p_zygote = Zygote.gradient(f, u0p)[1]
du0p = zeros(2)
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))  # fails with the MethodError below
@test du0p_zygote ≈ du0p

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching MixedDuplicated(::ODESolution{…}, ::ODESolution{…})

Closest candidates are:
  MixedDuplicated(::T1, ::Base.RefValue{T1}) where T1
   @ EnzymeCore C:\Users\Matt Bossart\.julia\packages\EnzymeCore\a2poZ\src\EnzymeCore.jl:163
  MixedDuplicated(::T1, ::Base.RefValue{T1}, ::Bool) where T1
   @ EnzymeCore C:\Users\Matt Bossart\.julia\packages\EnzymeCore\a2poZ\src\EnzymeCore.jl:163

 [1] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(SciMLBase.wrap_sol), df::Nothing, primal_1::ODESolution{…}, shadow_1_1::ODESolution{…})
   @ Enzyme.Compiler C:\Users\Matt Bossart\.julia\packages\Enzyme\aioBJ\src\rules\jitrules.jl:147
Some type information was truncated. Use `show(err)` to see complete types.
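The error says MixedDuplicated was called with an unwrapped ODESolution as its shadow, while both candidate constructors listed expect the shadow wrapped in a Base.RefValue. The standalone sketch below reproduces that constructor mismatch outside of DiffEqBase. It assumes EnzymeCore is available in the active environment (it is a dependency of Enzyme), and it uses a hypothetical tuple value as a stand-in for an ODESolution; it only illustrates the signature mismatch and is not a fix for the rule.

```julia
# Assumes EnzymeCore is available in the active environment.
using EnzymeCore: MixedDuplicated

# Hypothetical stand-in for a primal value such as an ODESolution:
# a tuple mixing an immutable scalar with a mutable array.
x  = (1.0, [2.0, 3.0])
dx = (0.0, zeros(2))   # shadow with the same type as x

# Matches the candidate MixedDuplicated(::T1, ::Base.RefValue{T1}):
# the shadow is wrapped in a Ref.
ok = MixedDuplicated(x, Ref(dx))

# Passing the shadow unwrapped (as in the stack trace, where primal and
# shadow are both ODESolution) has no matching constructor method.
err = try
    MixedDuplicated(x, dx)
    nothing
catch e
    e
end

println(err isa MethodError)
```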

Environment:

  • Output of using Pkg; Pkg.status()
Status `C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\TestingEnzyme\Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.4.2
  [d360d2e6] ChainRulesCore v1.24.0
  [b0b7db55] ComponentArrays v0.15.13
  [bcd4f6db] DelayDiffEq v5.47.3
  [2b5f629d] DiffEqBase v6.151.4
  [7da242da] Enzyme v0.12.17
  [26cc04aa] FiniteDifferences v0.12.32
  [f6369f11] ForwardDiff v0.10.36
  [929cbde3] LLVM v7.2.1
  [8913a72c] NonlinearSolve v3.13.0
  [1dea7af3] OrdinaryDiffEq v6.84.0
  [f0f68f2c] PlotlyJS v0.18.13
  [bed98974] PowerNetworkMatrices v0.10.3
  [398b2ede] PowerSimulationsDynamics v0.14.2
  [f00506e0] PowerSystemCaseBuilder v1.2.5
  [bcd98974] PowerSystems v3.3.0
  [295af30f] Revise v3.5.14
  [0bca4576] SciMLBase v2.41.3
  [1ed8b502] SciMLSensitivity v7.61.1
  [a759f4b9] TimerOutputs v0.5.24
  [e88e6eb3] Zygote v0.6.70
  • Output of versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 16 × 11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 16 virtual cores)

m-bossart, Jun 24 '24 13:06