DiffEqFlux.jl

Start using the less verbose `Lux.@compact` API

Status: Open · opened by avik-pal 10 months ago · 0 comments

Current implementation:

# Layer that uses `model` as the right-hand side of an ODE; calling it (see the
# callable method) solves the ODE over `tspan` and returns the solution plus state.
# NOTE(review): `@concrete` is presumably ConcreteStructs.jl, which turns the
# untyped fields below into concrete type parameters — confirm.
@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
    model::M   # the layer evaluated as the ODE right-hand side: du/dt = model(u, p)
    tspan      # time span handed to `ODEProblem`
    args       # extra positional arguments splatted into `solve` (presumably the solver)
    kwargs     # keyword arguments splatted into `solve`; can override the default `sensealg`
end

"""
    NeuralODE(model, tspan, args...; kwargs...)

Build a `NeuralODE` layer. A `model` that is not already an `AbstractExplicitLayer`
is first converted with `Lux.transform`. The positional `args` and keyword `kwargs`
are stored as-is and forwarded to `solve` when the layer is called.
"""
function NeuralODE(model, tspan, args...; kwargs...)
    lux_model = model isa AbstractExplicitLayer ? model : Lux.transform(model)
    # Dispatches to the struct's own 4-field constructor, because `lux_model` is
    # now an `AbstractExplicitLayer`.
    return NeuralODE(lux_model, tspan, args, kwargs)
end

# Solve the neural ODE starting from initial condition `x` with parameters `p`.
# Returns `(solution, st)`, where `st` is the layer state as left after solving.
function (n::NeuralODE)(x, p, st)
    # Wrap the layer together with its state so the ODE right-hand side can call
    # it as `smodel(u, θ)` (this avoids the closure-boxing issues noted below).
    smodel = StatefulLuxLayer(n.model, nothing, st)

    rhs(u, θ, t) = smodel(u, θ)
    odefunc = ODEFunction{false}(rhs; tgrad = basic_tgrad)
    problem = ODEProblem{false}(odefunc, x, n.tspan, p)

    # Default adjoint; a `sensealg` entry in `n.kwargs` comes later in the call
    # and therefore takes precedence.
    default_sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    sol = solve(problem, n.args...; sensealg = default_sensealg, n.kwargs...)
    return sol, smodel.st
end

With `@compact`, this would become the following. Positional-argument splatting (`args...`) won't work, but keyword-argument splatting (`kwargs...`) is fine:

"""
    NeuralODE(model, tspan, solver = nothing; kwargs...)

Build a neural-ODE layer via `Lux.@compact`. A `model` that is not already an
`AbstractExplicitLayer` is converted with `FromFluxAdaptor()`.

Calling the layer on `x` solves the ODE du/dt = model(u, θ) over `tspan` with
`solver`. The default `sensealg` is `InterpolatingAdjoint(; autojacvec = ZygoteVJP())`.
Any `kwargs` are forwarded to `solve`; a `sensealg` in `kwargs` comes last in the
call and so overrides the default.
"""
function NeuralODE(model, tspan, solver = nothing; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model))
    return @compact(; model, tspan, solver,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), kwargs...) do x, p
        # NOTE(review): confirm how `@compact` binds `p` in the do-block. The code
        # below assumes `p.model` holds the parameters of the wrapped `model`.
        # The inner parameter is named `θ` so it does not shadow the do-block `p`.
        dudt(u, θ, t) = model(u, θ)
        ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
        # Use the captured `tspan`: no `n` is in scope here. The old
        # callable-struct version read `n.tspan`.
        prob = ODEProblem(ff, x, tspan, p.model)
        @return solve(prob, solver; sensealg, kwargs...)
    end
end

This also handles all the boxing issues automatically, which were the reason we had to add `StatefulLuxLayer`.

I'm not sure whether this counts as breaking. After this change, end users won't be able to dispatch on `foo(::NeuralODE)`. However, we never guaranteed that; there is precedent in NonlinearSolve.jl, where we turned algorithms into functions rather than types.

This needs https://github.com/LuxDL/Lux.jl/pull/584, which will be released later today.

— avik-pal, Apr 12 '24, 21:04