DifferentialEquations.jl
Broken type stability when using auto-switching ODE solvers
I am trying to compute a loss function of an ODE problem. While checking the type stability of my code, I found that using auto-switching ODE solvers breaks the type stability of my function. Below is a minimal reproducible example:
using OrdinaryDiffEq, Test
function loss(p, alg)
    function f(u, p, t)
        du1 = p[1] * u[1] - p[2] * u[1] * u[2]
        du2 = p[3] * u[1] * u[2] - p[4] * u[2]
        return [du1, du2]
    end
    u0 = [2., 1.]
    tspan = (0., 10.)
    prob = ODEProblem{false}(f, u0, tspan, p)
    sol = solve(prob, alg)
    uf = sol.u[end]
    res1 = u0[1] - 1.0
    res2 = u0[2] - 1.0
    res3 = uf[1] - 1.0
    res4 = uf[2] - 1.0
    return [res1, res2, res3, res4]
end
p = Float64[1.0, 1.0, 1.0, 1.0]
# Type stable
@inferred loss(p, Tsit5())
@inferred loss(p, Vern7())
# Type unstable
@inferred loss(p, AutoTsit5(Rosenbrock23()))
@inferred loss(p, AutoVern7(Rosenbrock23()))
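
For context, here is a minimal sketch of what I understand `@inferred` to be checking. It uses two toy functions, `stable` and `unstable`, which are hypothetical names made up for illustration and have nothing to do with OrdinaryDiffEq internals. It only needs the `Test` standard library:

```julia
using Test

# Hypothetical toy functions, for illustration only.
stable(x) = x > 0 ? 1.0 : 0.0    # both branches return Float64
unstable(x) = x > 0 ? 1.0 : 0    # one branch returns Float64, the other Int

# Passes: the inferred return type is the concrete type Float64.
@inferred stable(2.0)

# The inferred return types can also be inspected directly.
stable_types = Base.return_types(stable, (Float64,))
unstable_types = Base.return_types(unstable, (Float64,))

# @inferred throws when the inferred type differs from the actual return type.
threw = try
    @inferred unstable(2.0)
    false
catch
    true
end
```

Presumably the same inspection applies to the example above, e.g. `Base.return_types(loss, (typeof(p), typeof(alg)))` with `alg` set to one of the auto-switching algorithms, but I have not included that output here.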