DelayDiffEq.jl

Pass full integrator instead of parameters

Open devmotion opened this issue 5 years ago • 5 comments

As discussed in https://github.com/JuliaDiffEq/DiffEqProblemLibrary.jl/pull/39, especially for the history function it seems reasonable to pass the full integrator as argument instead of only the parameters, i.e., having h(integrator, t) instead of h(p, t) and also f(u, h, integrator, t) instead of f(u, h, p, t). This would enable the user to write generic history functions with correct output types (see the discussion in the PR) and hopefully allow us to simplify the implementation in DelayDiffEq.
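To illustrate why the integrator helps here, a history function that receives the integrator can derive its output type from the state instead of hard-coding it. A minimal runnable sketch, where `MockIntegrator` is an illustrative stand-in for the real integrator, not an actual DelayDiffEq type:

```julia
# Stand-in for the integrator; in DelayDiffEq the real integrator would be passed.
struct MockIntegrator{U}
    u::U
end

# Generic constant history: the output type follows the state type automatically,
# so the same definition works for vector and scalar states of any element type.
h(integrator, t) = zero(integrator.u)

h(MockIntegrator([1.0, 2.0]), -1.0)  # [0.0, 0.0]
h(MockIntegrator(1.0f0), -1.0)       # 0.0f0
```

With the current h(p, t) signature, the user would have to hard-code the output container and element type instead.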

According to @ChrisRackauckas

we should have a common arg for using integrator instead of p, and then we just need to make every package handle that well.

I think we should approach this issue slightly differently. A user already has to decide whether to pass around the integrator or only the parameters when implementing f (or h); i.e., it is a property not of the numerical algorithm but of the differential equation function. Hence I guess it would make sense to handle this issue by modifying DiffEqFunctions instead of the individual algorithm packages. We could replace

abstract type AbstractDiffEqFunction{iip} <: Function end

with

abstract type AbstractDiffEqFunction{iip,unpackparams} <: Function end

and then define, e.g.,

(f::ODEFunction{true,unpackparams})(du, u, integrator, t) where {unpackparams} =
    unpackparams ? f.f(du, u, get_p(integrator), t) : f.f(du, u, integrator, t)

In that way, we just have to implement get_p for every integrator (which would be integrator.p by default) and could always pass integrator in every package.
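A self-contained sketch of this dispatch idea with stand-in types (`MiniFunction`, `MiniIntegrator`, and this `get_p` are illustrative, not the actual DiffEqBase definitions):

```julia
# Stand-in for an integrator carrying the parameters.
struct MiniIntegrator{P}
    p::P
end
get_p(integrator::MiniIntegrator) = integrator.p

# Stand-in for ODEFunction with the proposed `unpackparams` type parameter.
struct MiniFunction{unpackparams,F}
    f::F
end
MiniFunction{unpack}(f) where {unpack} = MiniFunction{unpack,typeof(f)}(f)

# The wrapper always receives the integrator and decides, at compile time,
# whether to unpack the parameters before calling the user function.
function (f::MiniFunction{unpackparams})(u, integrator, t) where {unpackparams}
    unpackparams ? f.f(u, get_p(integrator), t) : f.f(u, integrator, t)
end

integ = MiniIntegrator(2.0)
fp = MiniFunction{true}((u, p, t) -> p * u)        # expects plain parameters
fi = MiniFunction{false}((u, int, t) -> int.p * u) # expects the full integrator

fp(3.0, integ, 0.0)  # 6.0
fi(3.0, integ, 0.0)  # 6.0
```

Since `unpackparams` is a type parameter, the branch should be resolved during specialization rather than at runtime.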

devmotion avatar Jul 02 '19 08:07 devmotion

Yes, this makes sense. I am a little worried about compile times, but maybe it all just quickly compiles away.

ChrisRackauckas avatar Jul 02 '19 12:07 ChrisRackauckas

Yes, hopefully the compiler is smart enough.

However, there's another issue: in the same way, we would have to decide whether to pass the integrator to the tgrad, analytic, etc. functions. Of course, this could be handled at the level of DiffEqFunction in the same way as in the example above by using overloads such as f(Val{:analytic}, ...). But since we switched away from this form, I guess that's not a good idea :smile:

Alternatively, one could define functions such as

function DiffEqBase.analytic(f::ODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")

    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end 

for all such overloads, but I don't know if this makes any difference.

I still like the idea of attacking this problem at the lowest level, but of course an alternative would be to explicitly define p before every function call (or chunk of calls), e.g., by defining

function perform_step!(integrator, cache::BS3ConstantCache)
    p = unpack_params(integrator, integrator.f)
    .....
end

unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,false}) where iip = integrator
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,true}) where iip = get_p(integrator)
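A runnable version of this unpack_params pattern with stand-in types (`FakeIntegrator` and `FakeFunction` are illustrative substitutes for ODEIntegrator and the hypothetical two-parameter ODEFunction):

```julia
# Stand-in for the integrator carrying the parameters.
struct FakeIntegrator{P}
    p::P
end
get_p(integrator::FakeIntegrator) = integrator.p

# Stand-in for ODEFunction{iip,unpack}; only the `unpack` flag matters here.
struct FakeFunction{unpack} end

# Dispatch on the flag: either hand back the integrator itself or just `p`.
unpack_params(integrator::FakeIntegrator, ::FakeFunction{false}) = integrator
unpack_params(integrator::FakeIntegrator, ::FakeFunction{true}) = get_p(integrator)

integ = FakeIntegrator((a = 1.5,))
unpack_params(integ, FakeFunction{true}())   # (a = 1.5,)
unpack_params(integ, FakeFunction{false}())  # the integrator itself
```

Each perform_step! would then bind `p` once at the top and pass it everywhere the current code passes integrator.p.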

devmotion avatar Jul 02 '19 15:07 devmotion

We can also hack it with getproperty overloading

ChrisRackauckas avatar Jul 02 '19 15:07 ChrisRackauckas

I'm working on a prototype for ODEFunction and I still hope that not too many changes are necessary in OrdinaryDiffEq.

However, I'm not sure how to deal with the fact that p is used to construct the cache in https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/solve.jl#L246 before the ODEIntegrator exists. As far as I can see, p is mostly/only used to construct the Jacobian w.r.t. u for the nonlinear solvers, in lines such as https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/nlsolve/utils.jl#L195 that evaluate f.jac(uprev, p, t). I mean, if jac is given we want to use it, but I don't know how to retrieve its type if it expects a full integrator.

Can we get around this problem somehow by not caching W but passing it around when it's created?

devmotion avatar Jul 03 '19 21:07 devmotion

Since passing around the integrator in OrdinaryDiffEq is not completely straightforward (at least it seems so to me), I started playing around with something that's more centered around the use case in DelayDiffEq. One idea was to use getproperty overloading such that all calls of @unpack f = integrator or integrator.f in OrdinaryDiffEq return an ODEFunction with a history function that is built from the integrator, similar to the following simple example:

using DelayDiffEq, DiffEqBase, Test

struct ODEFunctionWrapper{iip,F,H} <: DiffEqBase.AbstractODEFunction{iip}
    f::F
    h::H
end

function wrap(prob::DDEProblem)
    ODEFunctionWrapper{isinplace(prob.f),typeof(prob.f),typeof(prob.h)}(prob.f, prob.h)
end

(f::ODEFunctionWrapper{false})(u, p, t) = f.f(u, f.h, p, t)
(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)

struct TestStruct{F,A}
    f::F
    a::A
end

function buildTestStruct(prob::DDEProblem, u, p, t)
    f = wrap(prob)
    a = f(u, p, t)

    TestStruct(f, a)
end

function buildTestStruct(prob::DDEProblem, du, u, p, t)
    f = wrap(prob)
    f(du, u, p, t)

    TestStruct(f, first(du))
end

function Base.getproperty(test::TestStruct, x::Symbol)
    if x === :f
        f = getfield(test, :f)
        if isinplace(f)
            (du, u, p, t) -> f.f(du, u, (p, t) -> [t * test.a], p, t)
        else
            (u, p, t) -> f.f(u, (p, t) -> t * test.a, p, t)
        end
    else
        getfield(test, x)
    end
end

function calc(test::TestStruct, u, p, t)
    f = test.f
    f(u, p, t)
end

function calc!(test::TestStruct, du, u, p, t)
    f = test.f
    f(du, u, p, t)
    nothing
end

function f_ip(du, u, h, p, t)
    du[1] = h(p, t)[1] - u[1]
    nothing
end

f_scalar(u, h, p, t) = h(p, t) - u

function test()
    prob_ip = DDEProblem(f_ip, [1.0], (p, t) -> [0.0], (0.0, 10.0))
    prob_scalar = DDEProblem(f_scalar, 1.0, (p, t) -> 0.0, (0.0, 10.0))

    wrap_ip = wrap(prob_ip)
    wrap_scalar = wrap(prob_scalar)

    a = [0.0]
    wrap_ip(a, [5.0], nothing, 0.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 5.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 10.0)
    @test a[1] == -5.0

    @test wrap_scalar(5.0, nothing, 0.0) == -5.0
    @test wrap_scalar(5.0, nothing, 5.0) == -5.0
    @test wrap_scalar(5.0, nothing, 10.0) == -5.0

    struct_ip = buildTestStruct(prob_ip, [0.0], [5.0], nothing, 4.0)
    @test struct_ip.a == -5.0

    struct_scalar = buildTestStruct(prob_scalar, 5.0, nothing, 4.0)
    @test struct_scalar.a == -5.0

    b = [0.0]
    calc!(struct_ip, b, [5.0], nothing, 1.0)
    @test b[1] == -10.0
    calc!(struct_ip, b, [5.0], nothing, 4.0)
    @test b[1] == -25.0

    @test calc(struct_scalar, 5.0, nothing, 2.0) == -15.0
    @test calc(struct_scalar, 5.0, nothing, 6.0) == -35.0
end

However, I'm not sure how this will affect performance, or whether it is feasible at all.

devmotion avatar Jul 11 '19 07:07 devmotion