SciMLSensitivity.jl
tgrad support for adjoint problem
If the ODEProblem has the field tgrad, and we construct an ODEAdjointProblem from it, it seems that the resulting problem does not have the tgrad field. Can you add support for it?
In other words, can you add a tgrad field to the following struct?
struct ODEAdjointSensitivityFunction{dgType,rateType,uType,F,J,PJ,UF,PF,G,JC,GC,A,SType,DG,MM,TJ,PJT,PJC,CP,INT} <: SensitivityFunction
    f::F
    jac::J
    paramjac::PJ
    uf::UF
    pf::PF
    g::G
    J::TJ
    pJ::PJT
    dg_val::dgType
    jac_config::JC
    g_grad_config::GC
    paramjac_config::PJC
    alg::A
    numparams::Int
    numindvar::Int
    f_cache::rateType
    discrete::Bool
    y::uType
    sol::SType
    dg::DG
    mass_matrix::MM
    checkpoints::CP
    integrator::INT
    tgrad # add tgrad field here!
end
Yeah, we should be able to add that and have it work. Right now, does it error, or does it just drop the user's tgrad and fall back to autodiff?
Thanks for the fast reply!
Right now it drops the user's tgrad and uses autodiff. Moreover, the autodiff gives an error when calling dudt with a ForwardDiff.Dual object (instead of a plain Float64 array):
expected ForwardDiff.Dual{..., Float64, 1}, got ForwardDiff.Dual{..., ForwardDiff.Dual{Nothing, Float64, 1}, 1}, where I am using ... to stand for a very long type.
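For context, my reading of that error (an illustration, not a fix): the adjoint pass is already carrying one level of ForwardDiff.Dual, and the internal time-gradient routine tries to add a second level, while its cache was allocated for a single level. Nested duals themselves are fine; it is the cache's element type that rejects them:

```julia
using ForwardDiff

d  = ForwardDiff.Dual(1.0, 1.0)  # one level:  Dual{Nothing,Float64,1}
dd = ForwardDiff.Dual(d, d)      # two levels: Dual{Nothing,Dual{Nothing,Float64,1},1}

# A container typed for single-level duals rejects the nested one in
# setindex!, which is essentially the typeassert failure above:
v = Vector{typeof(d)}(undef, 1)
v[1] = d        # fine
# v[1] = dd    # would throw: cannot convert a nested Dual to a single-level Dual
```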
Another comment: in the ODEAdjointProblem function in adjoint_sensitivity.jl, an ODEAdjointSensitivityFunction is converted to an ODEFunction object on the following line,
ODEProblem(sense,z0,tspan,p,callback=_cb)
However, when converting an ODEAdjointSensitivityFunction object to an ODEFunction, fields such as jac and tgrad are lost. Would you please also fix that? Thanks!
The Jacobian of the adjoint problem is not the Jacobian of the forward problem. The Jacobian of the ODE is used internally in the stepping function, though: https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/blob/master/src/adjoint_sensitivity.jl#L106-L108 . As for tgrad, the time gradient of the ODE is not the time gradient of the adjoint equation either. This is more difficult, since you'll need to apply the chain rule to the Jacobian of the ODE as well. It'll take a bit to write something out and derive a useful form.
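To spell out why (a sketch, assuming the standard continuous adjoint of the forward ODE, with Jacobian J = ∂f/∂u):

```latex
% Forward ODE and its continuous adjoint
\dot{u} = f(u, p, t), \qquad
\dot{\lambda} = -J(u(t), p, t)^{T} \lambda

% Jacobian of the adjoint RHS with respect to the adjoint state:
% the (sign-flipped) transpose of the forward Jacobian, not J itself
\frac{\partial \dot{\lambda}}{\partial \lambda} = -J^{T}

% Time gradient of the adjoint RHS: J depends on t both directly and
% through u(t), so the user's tgrad (\partial f / \partial t) alone is
% not enough -- the chain rule hits the Jacobian as well
\frac{\partial \dot{\lambda}}{\partial t}
  = -\left( \frac{\partial J}{\partial t}
    + \frac{\partial J}{\partial u}\, f(u, p, t) \right)^{T} \lambda
```

(The last contraction is schematic: ∂J/∂u is a third-order tensor contracted with f.)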
Thank you Chris, you just saved me a lot of time.
Right now I think the best way forward for me is to use autodiff and make it work. Actually, I am just trying to run the following code:
using Random
using LinearAlgebra
using DifferentialEquations
using Flux
using DiffEqFlux

struct DETN
    p; q; F; G; Z; L;
end

DETN(p::Int, q::Int) = DETN(p, q, Dense(p+q, p+q, sigmoid),
                                  Dense(p+q, p+q, sigmoid),
                                  Dense(p+q, p,   sigmoid),
                                  Dense(p+q, q,   sigmoid));
Flux.@treelike DETN

p = 5;
q = 1;
dt = 0.05f0;
sigma = 0.10f0;
tspan = (0.0f0, 100.0f0);
tsave = tspan[1]:dt:tspan[2];

function noise(q, t)
    vv = Vector(undef, q);
    for i in 1:length(vv)
        vv[i] = exp.(-0.5*((50.0 .- t)./sigma).^2.0)./(sigma*sqrt(2pi));
    end
    return Float32.(vv);
end

function neural_ode(model::DETN, x, tspan, args...; kwargs...)
    Flux.Tracker.istracked(x) && error("u0 is not currently differentiable.")
    p = DiffEqFlux.destructure(model)
    #-------------------------------------
    function dudt_(du, u, p, t)
        B = DiffEqFlux.restructure(model, p);
        du .= -B.F(u).*u + B.G(u).*vcat(B.Z(u), noise(B.q, t));
    end
    #-------------------------------------
    func = ODEFunction(dudt_);
    prob = ODEProblem(func, x, tspan, p);
    return diffeq_adjoint(p, prob, args...; kwargs...);
end

Random.seed!(0);
u0 = vcat(rand(p), zeros(q));
B = DETN(p, q);
ps = Flux.params(B);
solver = Rosenbrock23();
sol = neural_ode(B, u0, tspan, solver, saveat=tsave, reltol=1e-6, abstol=1e-6, dt=0.05);
And it gives me the following error,
ERROR: LoadError: TypeError: in setindex!, in typeassert, expected ForwardDiff.Dual{ForwardDiff.Tag{DiffEqDiffTools.TimeGradientWrapper{ODEFunction{true,getfield(Main, Symbol("#dudt_#4")){DETN},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float32,1}},Float32},Float64,1}, got ForwardDiff.Dual{ForwardDiff.Tag{DiffEqDiffTools.TimeGradientWrapper{ODEFunction{true,getfield(Main, Symbol("#dudt_#4")){DETN},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float32,1}},Float32},ForwardDiff.Dual{Nothing,Float64,1},1}
The error occurs at the step where autodiff calls my dudt_ with a ForwardDiff.Dual object, but I am not an expert in the ForwardDiff package. Would you help me with this?
Thanks in advance! Junteng
What about using solver = Rosenbrock23(autodiff=false)?
It is very slow.
Adjoints are a long calculation. Is there a speed baseline that we know this should be matching?
I would speed this up by instead using Rodas5(autodiff=false, diff_type=Val{:forward}), or maybe an SDIRK method can do better. But I'm not sure that autodiff vs. no autodiff inside the Rosenbrock routine would make much of a difference here. The main cost is the number of f calls and W inversions. Maybe qsteady with an SDIRK method? Or maybe BDF is just a solid method here, since it reduces f calls in exchange for more, but smaller, W inversions? CVODE_BDF() would then be an interesting thing to try. Or checkpointing? Finding the method that optimizes this is a rather different question.
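For reference, the candidate configurations mentioned above could be collected like this (a sketch only; solver names are from OrdinaryDiffEq.jl and Sundials.jl, and which one wins is problem-dependent, so this is a benchmarking list rather than a recommendation):

```julia
using OrdinaryDiffEq, Sundials

# Candidates to benchmark against the baseline Rosenbrock23(autodiff=false):
candidates = [
    Rodas5(autodiff = false, diff_type = Val{:forward}), # Rosenbrock, forward finite differences
    TRBDF2(autodiff = false),                            # an SDIRK method
    KenCarp4(autodiff = false),                          # another SDIRK option
    CVODE_BDF(),                                         # BDF: fewer f calls, more but cheaper W inversions
]
```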
Thank you for the advice, I will play with all those
We're also playing with a pretty large generated example today, so we might want to loop back and check the performance here after we optimize a few SDIRK methods. @YingboMa
That's good to know.
You probably don't want to use Rosenbrock methods for this. I will look into it.
Thanks!