ModelingToolkit.jl

ForwardDiff does not work directly through ODEProblem

Open hersle opened this issue 2 months ago • 3 comments

Auto-differentiating through remake(ODEProblem()) works, but not directly through ODEProblem():

using Test
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using ForwardDiff

# Minimal reproduction: differentiate x(1) with respect to the parameter P for
# the system D(x) ~ P with x(0) = 0, building the problem in two different ways.
@testset "ForwardDiff through ODEProblem with vs. without remake" begin
    @parameters P
    @variables x(t)
    sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name=:sys))
    
    # Solve the system on (0.0, 1.0) and evaluate the solution at t = 1.0.
    # `P` is the value substituted for the system parameter `sys.P`; under
    # ForwardDiff it arrives as a `ForwardDiff.Dual`.
    # `use_remake` selects how that value reaches the problem:
    #   true  -> construct the problem with a NaN placeholder, then `remake` it
    #            with the real value
    #   false -> pass the value directly to the `ODEProblem` constructor
    function x_at_1(P; use_remake = false)
        if use_remake
            # NaN is only a placeholder; it is replaced by `remake` on the next line
            prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [sys.P => NaN])
            prob = remake(prob; p = [sys.P => P])
        else
            # this is the path that throws when P is a Dual number
            prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [sys.P => P])
        end
        return solve(prob)(1.0)
    end

    @test_nowarn ForwardDiff.derivative(P -> x_at_1(P; use_remake=true),  1.0) # passes
    @test_nowarn ForwardDiff.derivative(P -> x_at_1(P; use_remake=false), 1.0) # fails
end

The last test errors with

  MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1})

  Closest candidates are:
    (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
     @ Base rounding.jl:207
    (::Type{T})(::T) where T<:Number
     @ Core boot.jl:792
    Float64(::IrrationalConstants.Loghalf)
     @ IrrationalConstants C:\Users\herma\.julia\packages\IrrationalConstants\vp5v4\src\macro.jl:112
    ...

  Stacktrace:
    [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1})
      @ Base .\number.jl:7
    [2] symconvert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\parameter_buffer.jl:2
    [3] ModelingToolkit.MTKParameters(sys::ODESystem, p::Dict{SymbolicUtils.BasicSymbolic{Real}, ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}}, u0::Vector{Pair{Num, Float64}}; tofloat::Bool, use_union::Bool)       
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\parameter_buffer.jl:114
    [4] ModelingToolkit.MTKParameters(sys::ODESystem, p::Dict{SymbolicUtils.BasicSymbolic{Real}, ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}}, u0::Vector{Pair{Num, Float64}})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\parameter_buffer.jl:13
    [5] process_DEProblem(constructor::Type, sys::ODESystem, u0map::Vector{Pair{Num, Float64}}, parammap::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}}}; implicit_dae::Bool, du0map::Nothing, version::Nothing, tgrad::Bool, jac::Bool, checkbounds::Bool, sparse::Bool, simplify::Bool, linenumbers::Bool, parallel::Symbolics.SerialForm, eval_expression::Bool, use_union::Bool, tofloat::Bool, symbolic_u0::Bool, u0_constructor::typeof(identity), guesses::Dict{Any, Any}, t::Float64, warn_initialize_determined::Bool, build_initializeprob::Bool, kwargs::@Kwargs{check_length::Bool})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:942
    [6] process_DEProblem
      @ C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:834 [inlined]
    [7] (ODEProblem{true, SciMLBase.AutoSpecialize})(sys::ODESystem, u0map::Vector{Pair{Num, Float64}}, tspan::Tuple{Float64, Float64}, parammap::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}}}; callback::Nothing, check_length::Bool, warn_initialize_determined::Bool, kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1085
    [8] (ODEProblem{true, SciMLBase.AutoSpecialize})(sys::ODESystem, u0map::Vector{Pair{Num, Float64}}, tspan::Tuple{Float64, Float64}, parammap::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}}})    
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1075
    [9] (ODEProblem{true})(::ODESystem, ::Vector{Pair{Num, Float64}}, ::Vararg{Any}; kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1062
   [10] (ODEProblem{true})(::ODESystem, ::Vector{Pair{Num, Float64}}, ::Vararg{Any})
      @ ModelingToolkit C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1061
   [11] #ODEProblem#755
      @ C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1051 [inlined]
   [12] ODEProblem
      @ C:\Users\herma\.julia\packages\ModelingToolkit\kByuD\src\systems\diffeqs\abstractodesystem.jl:1050 [inlined]
   [13] (::var"#x_at_1#23"{var"#x_at_1#16#24"{ODESystem, Num}})(P::ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1}; use_remake::Bool)     
      @ Main C:\Users\herma\Dropbox\School\UIO\Research\boltzmann\bug.jl:17    
   [14] (::var"#22#30")(P::ForwardDiff.Dual{ForwardDiff.Tag{var"#22#30", Float64}, Float64, 1})
      @ Main C:\Users\herma\Dropbox\School\UIO\Research\boltzmann\bug.jl:23    
   [15] derivative
      @ C:\Users\herma\.julia\packages\ForwardDiff\PcZ48\src\derivative.jl:14 [inlined]
    ...

Could the two code paths be unified so that differentiation works either way?

I'm on a freshly updated master branch; `] status` shows:

  [0c46a032] DifferentialEquations v7.13.0
  [f6369f11] ForwardDiff v0.10.36
  [961ee093] ModelingToolkit v9.12.1 `https://github.com/SciML/ModelingToolkit.jl.git#master`
  [1dea7af3] OrdinaryDiffEq v6.74.1

hersle avatar Apr 23 '24 08:04 hersle

Doing it through the whole symbolic system is a bit harder, but it should be possible. You most likely don't want to do this, though, for performance reasons: it would cause extra compilation inside the optimization loop. But if you're doing something like reinforcement learning and trying out new equations, then you might have some interesting applications for it.

ChrisRackauckas avatar Apr 29 '24 03:04 ChrisRackauckas

Thanks! In case it's helpful, here are some simplified details about my real use case:

I have an ODE model with, say, 3 variables constrained by 1 equation. I want the user to specify initial conditions for any 2 independent variables, and have MTK solve for the third dependent variable. I want all this to work with automatic differentiation.

Here is an example that works with FiniteDiff, but fails with ForwardDiff:

using Test
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using ForwardDiff, FiniteDiff

# ODE with three states x, y, z whose derivatives are all zero, plus one
# algebraic equation defining s as their sum (s = x + y + z)
@variables x(t) y(t) z(t) s(t)
sys = structural_simplify(ODESystem([
    D(x) ~ 0,
    D(y) ~ 0,
    D(z) ~ 0,
    s ~ x + y + z
], t, [x, y, z, s], []; name=:sys))

"""
    solve_instance(ics)

Build an `ODEProblem` for `sys` on the time span `(0.0, 1.0)` and return its solution.

`ics` is a collection of `variable => value` pairs; it is concatenated after the
fixed pair `s => 1.0` to form the initial-condition map.
"""
function solve_instance(ics)
    # s0 = x0 + y0 + z0 = 1
    u0map = [s => 1.0; ics]
    return solve(ODEProblem(sys, u0map, (0.0, 1.0)))
end

# user solves with initial conditions for different subsets of variables:
# each helper supplies two initial values and reads the third off the solution
# at t = 0.0
function x0(y0, z0)
    sol = solve_instance([y => y0, z => z0])
    return sol(0.0, idxs=x) # analytical: x0 = 1 - y0 - z0
end

function y0(x0, z0)
    sol = solve_instance([x => x0, z => z0])
    return sol(0.0, idxs=y) # analytical: y0 = 1 - x0 - z0
end

# note: the positional order of this method is (y0, x0)
function z0(y0, x0)
    sol = solve_instance([x => x0, y => y0])
    return sol(0.0, idxs=z) # analytical: z0 = 1 - x0 - y0
end

# evaluation passes
# (all three helpers are called with the same two values and compared to 0.7)
@test x0(0.1, 0.2) ≈ y0(0.1, 0.2) ≈ z0(0.1, 0.2) ≈ 0.7

# gradient passes with FiniteDiff, but fails with ForwardDiff
# The array methods splat a vector into the two-argument methods so that each
# helper can be passed to a gradient routine as a function of a single vector.
x0(ics::AbstractArray) = x0(ics...)
y0(ics::AbstractArray) = y0(ics...)
z0(ics::AbstractArray) = z0(ics...)
# ∇ is bound in turn to each gradient function; the same check runs for both backends
for ∇ in [FiniteDiff.finite_difference_gradient, ForwardDiff.gradient]
    # each analytical form is 1 minus the two inputs, so every gradient should be [-1, -1]
    @test all(∇(x0, [0.1, 0.2]) .≈ ∇(y0, [0.1, 0.2]) .≈ ∇(z0, [0.1, 0.2]) .≈ [-1, -1])
end

This fails with a very similar error (and the same root cause?):

  Test threw exception
  Expression: all(∇(x0, [0.1, 0.2]) .≈ ∇(y0, [0.1, 0.2]) .≈ ∇(z0, [0.1, 0.2]) .≈ [-1, -1])
  MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2})

  Closest candidates are:
    (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat      
     @ Base rounding.jl:207
    (::Type{T})(::T) where T<:Number
     @ Core boot.jl:792
    Float64(::IrrationalConstants.Loghalf)
     @ IrrationalConstants C:\Users\herma\.julia\packages\IrrationalConstants\vp5v4\src\macro.jl:112
    ...

  Stacktrace:
    [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2})
      @ Base .\number.jl:7
    [2] symconvert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\parameter_buffer.jl:2
    [3] ModelingToolkit.MTKParameters(sys::NonlinearSystem, p::Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}}, u0::Dict{Any, Any}; tofloat::Bool, use_union::Bool)
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\parameter_buffer.jl:114
    [4] MTKParameters
      @ C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\parameter_buffer.jl:13 [inlined]
    [5] process_NonlinearProblem(constructor::Type, sys::NonlinearSystem, u0map::Dict{Any, Any}, parammap::Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}}; version::Nothing, jac::Bool, checkbounds::Bool, sparse::Bool, simplify::Bool, linenumbers::Bool, parallel::Symbolics.SerialForm, eval_expression::Bool, use_union::Bool, tofloat::Bool, kwargs::@Kwargs{check_length::Bool})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:400
    [6] process_NonlinearProblem
      @ C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:384 [inlined]
    [7] (NonlinearProblem{true})(sys::NonlinearSystem, u0map::Dict{Any, Any}, parammap::Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}}; check_length::Bool, kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:435
    [8] (NonlinearProblem{true})(sys::NonlinearSystem, u0map::Dict{Any, Any}, parammap::Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:429
    [9] NonlinearProblem(::NonlinearSystem, ::Dict{Any, Any}, ::Vararg{Any}; kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:426
   [10] NonlinearProblem(::NonlinearSystem, ::Dict{Any, Any}, ::Vararg{Any})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\nonlinear\nonlinearsystem.jl:425
   [11] ModelingToolkit.InitializationProblem{true, SciMLBase.AutoSpecialize}(sys::ODESystem, t::Float64, u0map::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, parammap::SciMLBase.NullParameters; guesses::Dict{Any, Any}, check_length::Bool, warn_initialize_determined::Bool, kwargs::@Kwargs{})       
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1636
   [12] (ModelingToolkit.InitializationProblem{true})(::ODESystem, ::Float64, ::Vararg{Any}; kwargs::@Kwargs{guesses::Dict{Any, Any}, warn_initialize_determined::Bool})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1572
   [13] ModelingToolkit.InitializationProblem(::ODESystem, ::Float64, ::Vararg{Any}; kwargs::@Kwargs{guesses::Dict{Any, Any}, warn_initialize_determined::Bool})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1560
   [14] process_DEProblem(constructor::Type, sys::ODESystem, u0map::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, parammap::SciMLBase.NullParameters; implicit_dae::Bool, du0map::Nothing, version::Nothing, tgrad::Bool, jac::Bool, checkbounds::Bool, sparse::Bool, simplify::Bool, linenumbers::Bool, parallel::Symbolics.SerialForm, eval_expression::Bool, use_union::Bool, tofloat::Bool, symbolic_u0::Bool, u0_constructor::typeof(identity), guesses::Dict{Any, Any}, t::Float64, warn_initialize_determined::Bool, build_initializeprob::Bool, kwargs::@Kwargs{check_length::Bool})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:921
   [15] process_DEProblem
      @ C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:834 [inlined]
   [16] (ODEProblem{true, SciMLBase.AutoSpecialize})(sys::ODESystem, u0map::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, tspan::Tuple{Float64, Float64}, parammap::SciMLBase.NullParameters; callback::Nothing, check_length::Bool, warn_initialize_determined::Bool, kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1085
   [17] ODEProblem
      @ C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1075 [inlined]
   [18] (ODEProblem{true, SciMLBase.AutoSpecialize})(sys::ODESystem, u0map::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, tspan::Tuple{Float64, Float64})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1075
   [19] (ODEProblem{true})(::ODESystem, ::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, ::Vararg{Any}; kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1062
   [20] (ODEProblem{true})(::ODESystem, ::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, ::Vararg{Any})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1061
   [21] ODEProblem(::ODESystem, ::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, ::Vararg{Any}; kwargs::@Kwargs{})
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1051
   [22] ODEProblem(::ODESystem, ::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}}, ::Vararg{Any}) 
      @ ModelingToolkit C:\Users\herma\.julia\dev\ModelingToolkit\src\systems\diffeqs\abstractodesystem.jl:1050
   [23] solve_instance(ics::Vector{Pair{Num, ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}})
      @ Main .\REPL[9]:2
   [24] y0(x0::ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}, z0::ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2})
      @ Main .\REPL[12]:1
   [25] y0
      @ .\REPL[18]:1 [inlined]
   [26] vector_mode_dual_eval!
      @ C:\Users\herma\.julia\packages\ForwardDiff\PcZ48\src\apiutils.jl:24 [inlined]
   [27] vector_mode_gradient(f::typeof(y0), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(y0), Float64}, Float64, 2}}})
      @ ForwardDiff C:\Users\herma\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:89
    ...

hersle avatar Apr 29 '24 09:04 hersle

Would it be more efficient to solve the initialization problem rather than create a new ODE problem? For example:

# continued from above code snippet

# Solve only the initialization problem of `sys` at time 0.0, with s fixed to 1.0
# and y, z given, and read x from the result.
function x0init(y0, z0)
    u0map = [s => 1.0, y => y0, z => z0]
    initprob = ModelingToolkit.InitializationProblem(sys, 0.0, u0map, [])
    return solve(initprob)[x]
end
@test x0init(0.1, 0.2) ≈ 0.7 # passes

# vector method: splat the entries into the two-argument method
function x0init(ics::AbstractArray)
    return x0init(ics...)
end
for ∇ in [
    FiniteDiff.finite_difference_gradient,
    ForwardDiff.gradient,
]
    @test all(∇(x0init, [0.1, 0.2]) .≈ [-1, -1]) # fails with ForwardDiff
end

This also fails with ForwardDiff with the (seemingly) same error.

Having all this work seamlessly would be absolutely fantastic 😃

hersle avatar Apr 29 '24 09:04 hersle