DiffEqFlux.jl
Start using the less verbose `Lux.@compact` API
Current version:
```julia
@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
    model::M
    tspan
    args
    kwargs
end

function NeuralODE(model, tspan, args...; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = Lux.transform(model))
    return NeuralODE(model, tspan, args, kwargs)
end

function (n::NeuralODE)(x, p, st)
    model = StatefulLuxLayer(n.model, nothing, st)
    dudt(u, p, t) = model(u, p)
    ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, x, n.tspan, p)
    return (
        solve(prob, n.args...;
            sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), n.kwargs...),
        model.st)
end
```
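For context, using the current layer looks roughly like this (a minimal sketch; the `Chain`/`Dense` model, `Tsit5` solver, and shapes are illustrative choices, not part of this issue):

```julia
using DiffEqFlux, Lux, OrdinaryDiffEq, Random

model = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
node = NeuralODE(model, (0.0f0, 1.0f0), Tsit5(); saveat = 0.1f0)

ps, st = Lux.setup(Random.default_rng(), node)
sol, st_ = node(randn(Float32, 2), ps, st)  # returns (ODE solution, updated state)
```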
This would become the following (positional argument splatting `args...` won't work, but keyword argument splatting `kwargs...` is fine):
```julia
function NeuralODE(model, tspan, solver = nothing; kwargs...)
    !(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model))
    return @compact(; model, tspan, solver,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), kwargs...) do x, p
        dudt(u, p, t) = model(u, p)
        # note: `tspan`, not `n.tspan` -- there is no `n` in scope inside `@compact`
        prob = ODEProblem(ODEFunction{false}(dudt; tgrad = basic_tgrad), x, tspan, p.model)
        @return solve(prob, solver; sensealg, kwargs...)
    end
end
```
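From the user's side nothing should change, since `@compact` produces an ordinary Lux layer (a hypothetical usage sketch under the proposal above; the model's parameters would now live under `ps.model`):

```julia
using DiffEqFlux, Lux, OrdinaryDiffEq, Random

node = NeuralODE(Chain(Dense(2 => 16, tanh), Dense(16 => 2)), (0.0f0, 1.0f0), Tsit5())
ps, st = Lux.setup(Random.default_rng(), node)
sol, st_ = node(randn(Float32, 2), ps, st)
```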
This also handles all the boxing issues automatically (the reason we had to add `StatefulLuxLayer` in the first place).
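For reference, the boxing issue is the usual Julia closure pitfall: reassigning a captured variable (here the Lux state `st`) lowers to a `Core.Box` and kills type stability. A minimal sketch of the problematic pattern (not library code):

```julia
# Threading Lux state through the ODE rhs by hand requires rebinding `st`,
# which Julia captures via a type-unstable Core.Box:
function make_dudt(model, st)
    dudt = function (u, p, t)
        y, st = model(u, p, st)  # reassigns the captured `st` => Core.Box
        return y
    end
    return dudt
end
```

`StatefulLuxLayer` avoided this by keeping the state in a struct field; `@compact` does the state threading itself.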
I'm not sure whether this counts as breaking. The end user won't be able to dispatch via `foo(::NeuralODE)` after this, but we never guaranteed that (consider the NonlinearSolve.jl precedent, where we made the algorithms functions rather than types).
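Concretely, the kind of (hypothetical) user code this would break:

```julia
using DiffEqFlux

foo(n::NeuralODE) = "layer-specific handling"  # dispatch on the struct works today
# Under the proposal, `NeuralODE(...)` is just a constructor function returning a
# `@compact` layer, so there is no `NeuralODE` type left for this method to match.
```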
Needs https://github.com/LuxDL/Lux.jl/pull/584, which will be released later today.