DelayDiffEq.jl
Pass full integrator instead of parameters
As discussed in https://github.com/JuliaDiffEq/DiffEqProblemLibrary.jl/pull/39, especially for the history function it seems reasonable to pass the full integrator as an argument instead of only the parameters, i.e., having `h(integrator, t)` instead of `h(p, t)`, and also `f(u, h, integrator, t)` instead of `f(u, h, p, t)`. This would enable users to write generic history functions with correct output types (see the discussion in the PR) and hopefully allow us to simplify the implementation in DelayDiffEq.
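To illustrate the point, here is a minimal self-contained sketch (`FakeIntegrator` and its fields are stand-ins for this example, not the actual DelayDiffEq internals) of why passing the integrator makes history functions generic:

```julia
# Sketch: with `h(integrator, t)` the history can derive its output type from
# the integrator's state instead of hardcoding a Vector or a scalar.
# `FakeIntegrator` is an illustrative stand-in for the real integrator type.
struct FakeIntegrator{U,P}
    u::U  # current state
    p::P  # parameters
end

# One generic history function that works for any state type:
h(integrator::FakeIntegrator, t) = zero(integrator.u)

h(FakeIntegrator([1.0, 2.0], nothing), 0.0)  # [0.0, 0.0], matches the Vector state
h(FakeIntegrator(1.0, nothing), 0.0)         # 0.0, matches the scalar state
```

With the current `h(p, t)` signature the history function has no way to infer the state's type, so the user has to hardcode the output type.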
According to @ChrisRackauckas:

> we should have a common arg for using integrator instead of p, and then we just need to make every package handle that well.
I think we should approach this issue slightly differently. A user has to decide whether to pass around the integrator or only the parameters already when implementing `f` (or `h`), i.e., it is a property that does not depend on the numerical algorithm but rather on the differential equation function. Hence I guess it would make sense to handle this issue by modifying `DiffEqFunction`s instead of different algorithms. We could replace
```julia
abstract type AbstractDiffEqFunction{iip} <: Function end
```
with
```julia
abstract type AbstractDiffEqFunction{iip,unpackparams} <: Function end
```
and then define, e.g.,
```julia
(f::ODEFunction{true,unpackparams})(du, u, integrator, t) where unpackparams =
    unpackparams ? f.f(du, u, get_p(integrator), t) : f.f(du, u, integrator, t)
```
In that way, we just have to implement `get_p` for every integrator (which would be `integrator.p` by default) and could always pass `integrator` in every package.
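To make the proposal concrete, here is a minimal self-contained sketch with mock types (`MockODEFunction`, `MockIntegrator`, and `get_p` are illustrative stand-ins, not the real DiffEqBase definitions; the `iip` parameter is omitted for brevity):

```julia
# Mock integrator carrying the parameters; `get_p` extracts them.
struct MockIntegrator{P}
    p::P
end
get_p(integrator::MockIntegrator) = integrator.p

# Mock DiffEqFunction with the proposed `unpackparams` type parameter.
struct MockODEFunction{unpackparams,F}
    f::F
end
MockODEFunction{unpack}(f) where {unpack} = MockODEFunction{unpack,typeof(f)}(f)

# Packages always pass the integrator; the type parameter decides whether
# the user's function sees `p` or the full integrator.
(f::MockODEFunction{true})(u, integrator, t) = f.f(u, get_p(integrator), t)
(f::MockODEFunction{false})(u, integrator, t) = f.f(u, integrator, t)

int = MockIntegrator(2.0)
f_p = MockODEFunction{true}((u, p, t) -> p * u)               # old-style f(u, p, t)
f_int = MockODEFunction{false}((u, i, t) -> i.p * u + t)      # new-style f(u, integrator, t)
f_p(3.0, int, 0.0)    # 6.0
f_int(3.0, int, 1.0)  # 7.0
```

Since `unpackparams` is a type parameter, the branch is resolved by dispatch at compile time, so it should not cost anything at runtime.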
Yes, this makes sense. I am a little worried about compile times, but maybe it all just quickly compiles away.
Yes, hopefully the compiler is smart enough.
However, there's another issue: in the same way we have to pass around `integrator` to the `tgrad`, `analytic`, etc. functions (or not). Of course, this could be ensured at the level of `DiffEqFunction` in the same way as in the example above by using overloads such as `f(Val{:analytic}, ...)`. But since we switched away from this form, I guess that's not a good idea :smile:
Alternatively, one could define functions such as
```julia
function DiffEqBase.analytic(f::ODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")
    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end
```
for all such overloads, but I don't know if this makes any difference.
I still like the idea of attacking this problem at the lowest level, but of course an alternative would be to explicitly define `p` before every (chunk of) function calls, e.g., by defining
```julia
function perform_step!(integrator, cache::BS3ConstantCache)
    p = unpack_params(integrator, integrator.f)
    .....
end

unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,false}) where iip = integrator
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,true}) where iip = get_p(integrator)
```
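A self-contained mock of this pattern (the `Demo*` names and field layout are assumptions for illustration, not the real OrdinaryDiffEq types) shows how a step routine written once handles both conventions:

```julia
# Mock DiffEqFunction with an `unpack` type parameter, as proposed above.
struct DemoODEFunction{unpack,F}
    f::F
end
DemoODEFunction{unpack}(f) where {unpack} = DemoODEFunction{unpack,typeof(f)}(f)
(f::DemoODEFunction)(u, p, t) = f.f(u, p, t)

# Mock integrator with the fields the sketch needs.
struct DemoIntegrator{F,P,U}
    f::F
    p::P
    u::U
    t::Float64
end

# `true`: the user's f expects the parameters; `false`: the full integrator.
unpack_params(integrator::DemoIntegrator, ::DemoODEFunction{true}) = integrator.p
unpack_params(integrator::DemoIntegrator, ::DemoODEFunction{false}) = integrator

# A perform_step!-style routine, written once, agnostic of the convention:
function step_rhs(integrator::DemoIntegrator)
    p = unpack_params(integrator, integrator.f)
    integrator.f(integrator.u, p, integrator.t)
end

int_old = DemoIntegrator(DemoODEFunction{true}((u, p, t) -> p * u), 2.0, 3.0, 0.0)
int_new = DemoIntegrator(DemoODEFunction{false}((u, i, t) -> i.p * u), 2.0, 3.0, 0.0)
step_rhs(int_old)  # 6.0
step_rhs(int_new)  # 6.0
```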
We can also hack it with `getproperty` overloading.
I'm working on a prototype for `ODEFunction`, and I still hope that not too many changes are necessary in OrdinaryDiffEq. However, I'm not sure how to deal with the fact that `p` is used to construct the cache in https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/solve.jl#L246 before the ODEIntegrator exists. As far as I can see, `p` is mostly/only used to construct the Jacobian w.r.t. `u` for the nonlinear solvers in lines such as https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/nlsolve/utils.jl#L195 to evaluate `f.jac(uprev, p, t)`. I mean, if `jac` is given we want to use it, but I don't know how to retrieve its type if it expects a full integrator.
Can we get around this problem somehow by not caching `W` but passing it around when it's created?
Since passing around the integrator in OrdinaryDiffEq is not completely straightforward (at least it seems to me), I started playing around with something that's more centered around the use case in DelayDiffEq. One idea was to use `getproperty` overloading such that all calls of `@unpack f = integrator` or `integrator.f` in OrdinaryDiffEq return an ODE function with a history that is built on `integrator`, similar to the following simple example:
```julia
using DelayDiffEq, DiffEqBase, Test

struct ODEFunctionWrapper{iip,F,H} <: DiffEqBase.AbstractODEFunction{iip}
    f::F
    h::H
end

function wrap(prob::DDEProblem)
    ODEFunctionWrapper{isinplace(prob.f),typeof(prob.f),typeof(prob.h)}(prob.f, prob.h)
end

(f::ODEFunctionWrapper{false})(u, p, t) = f.f(u, f.h, p, t)
(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)

struct TestStruct{F,A}
    f::F
    a::A
end

function buildTestStruct(prob::DDEProblem, u, p, t)
    f = wrap(prob)
    a = f(u, p, t)
    TestStruct(f, a)
end

function buildTestStruct(prob::DDEProblem, du, u, p, t)
    f = wrap(prob)
    f(du, u, p, t)
    TestStruct(f, first(du))
end

function Base.getproperty(test::TestStruct, x::Symbol)
    if x === :f
        f = getfield(test, :f)
        if isinplace(f)
            (du, u, p, t) -> f.f(du, u, (p, t) -> [t * test.a], p, t)
        else
            (u, p, t) -> f.f(u, (p, t) -> t * test.a, p, t)
        end
    else
        getfield(test, x)
    end
end

function calc(test::TestStruct, u, p, t)
    f = test.f
    f(u, p, t)
end

function calc!(test::TestStruct, du, u, p, t)
    f = test.f
    f(du, u, p, t)
    nothing
end

function f_ip(du, u, h, p, t)
    du[1] = h(p, t)[1] - u[1]
    nothing
end

f_scalar(u, h, p, t) = h(p, t) - u

function test()
    prob_ip = DDEProblem(f_ip, [1.0], (p, t) -> [0.0], (0.0, 10.0))
    prob_scalar = DDEProblem(f_scalar, 1.0, (p, t) -> 0.0, (0.0, 10.0))

    wrap_ip = wrap(prob_ip)
    wrap_scalar = wrap(prob_scalar)

    a = [0.0]
    wrap_ip(a, [5.0], nothing, 0.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 5.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 10.0)
    @test a[1] == -5.0

    @test wrap_scalar(5.0, nothing, 0.0) == -5.0
    @test wrap_scalar(5.0, nothing, 5.0) == -5.0
    @test wrap_scalar(5.0, nothing, 10.0) == -5.0

    struct_ip = buildTestStruct(prob_ip, [0.0], [5.0], nothing, 4.0)
    @test struct_ip.a == -5.0
    struct_scalar = buildTestStruct(prob_scalar, 5.0, nothing, 4.0)
    @test struct_scalar.a == -5.0

    b = [0.0]
    calc!(struct_ip, b, [5.0], nothing, 1.0)
    @test b[1] == -10.0
    calc!(struct_ip, b, [5.0], nothing, 4.0)
    @test b[1] == -25.0

    @test calc(struct_scalar, 5.0, nothing, 2.0) == -15.0
    @test calc(struct_scalar, 5.0, nothing, 6.0) == -35.0
end
```
However, I'm not sure how this will affect performance, or whether it is feasible at all.
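One way to probe that concern is on a stripped-down version of the `getproperty` trick (the `Wrapper` type and `rhs` below are illustrative; whether the compiler elides the freshly built closure is exactly what would need checking):

```julia
# Stripped-down version of the getproperty overload, small enough to inspect
# with the usual tools (`@allocated`, `@code_warntype`).
struct Wrapper{F,A}
    f::F
    a::A
end

function Base.getproperty(w::Wrapper, x::Symbol)
    if x === :f
        f = getfield(w, :f)
        # A fresh closure is constructed on every access of `w.f`:
        (u, p, t) -> f(u, (p, t) -> t * getfield(w, :a), p, t)
    else
        getfield(w, x)
    end
end

rhs(u, h, p, t) = h(p, t) - u
w = Wrapper(rhs, -5.0)

call(w, u, p, t) = (w.f)(u, p, t)
call(w, 5.0, nothing, 2.0)  # -15.0

# After a warm-up call, `@allocated call(w, 5.0, nothing, 2.0)` and
# `@code_warntype call(w, 5.0, nothing, 2.0)` show whether the closure
# construction is inferred and optimized away.
```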