
tgrad support for adjoint problem

Open · 000Justin000 opened this issue 6 years ago · 14 comments

If the ODEProblem has the field tgrad, and we construct an ODEAdjointProblem from it, the resulting problem does not seem to have the tgrad field. Can you add support for this?

000Justin000 avatar Mar 05 '19 18:03 000Justin000

In other words, can you add the tgrad field to the following struct?

struct ODEAdjointSensitivityFunction{dgType,rateType,uType,F,J,PJ,UF,PF,G,JC,GC,A,SType,DG,MM,TJ,PJT,PJC,CP,INT} <: SensitivityFunction
  f::F
  jac::J
  paramjac::PJ
  uf::UF
  pf::PF
  g::G
  J::TJ
  pJ::PJT
  dg_val::dgType
  jac_config::JC
  g_grad_config::GC
  paramjac_config::PJC
  alg::A
  numparams::Int
  numindvar::Int
  f_cache::rateType
  discrete::Bool
  y::uType
  sol::SType
  dg::DG
  mass_matrix::MM
  checkpoints::CP
  integrator::INT
  tgrad # add tgrad field here!
end
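
For reference, on the forward problem a user tgrad is attached through the `ODEFunction` keyword constructor; a minimal sketch using the standard DifferentialEquations.jl API (the specific `f` and `f_tgrad` here are made-up illustrations, nothing adjoint-specific):

```julia
using OrdinaryDiffEq

# out-of-place RHS with an explicit time dependence, so ∂f/∂t is nonzero
f(u, p, t) = -p[1] .* u .+ sin(t)

# analytic ∂f/∂t, supplied so the solver need not autodiff through time
f_tgrad(u, p, t) = fill(cos(t), length(u))

func = ODEFunction(f; tgrad = f_tgrad)
prob = ODEProblem(func, [1.0], (0.0, 1.0), [2.0])
sol  = solve(prob, Rosenbrock23())
```

The question in this issue is whether this user-supplied tgrad can survive (in some transformed form) into the ODEAdjointProblem.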

000Justin000 avatar Mar 05 '19 18:03 000Justin000

Yeah, we should be able to, and it should work. Right now, does it error, or does it just drop the user tgrad and fall back to autodiff?

ChrisRackauckas avatar Mar 05 '19 18:03 ChrisRackauckas

Thanks for the fast reply!

Right now it drops the user-supplied tgrad and uses autodiff. Moreover, the autodiff gives an error when calling dudt with a ForwardDiff.Dual object (instead of a plain Float64 array):

expected ForwardDiff.Dual{..., Float64, 1}, got ForwardDiff.Dual{..., ForwardDiff.Dual{Nothing, Float64, 1}, 1}, where I am using ... to stand in for a very long type parameter.

000Justin000 avatar Mar 05 '19 19:03 000Justin000

Another comment: in the ODEAdjointProblem function in adjoint_sensitivity.jl, an ODEAdjointSensitivityFunction is implicitly converted to an ODEFunction in the following line,

ODEProblem(sense,z0,tspan,p,callback=_cb)

However, when an ODEAdjointSensitivityFunction object is converted to an ODEFunction, fields such as jac and tgrad are lost. Would you please also fix that? Thanks!

000Justin000 avatar Mar 05 '19 23:03 000Justin000

The Jacobian of the adjoint problem is not the Jacobian of the forward problem. The Jacobian of the ODE is used internally in the stepping function, though: https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/blob/master/src/adjoint_sensitivity.jl#L106-L108 . As for tgrad, the time-gradient of the ODE is not the time-gradient of the adjoint equation either. This one is more difficult, since you'd need to apply the chain rule to the Jacobian of the ODE as well. It'll take a bit to write something out and derive a useful form.
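
A rough sketch of why (with $\lambda$ the adjoint variable, $y(t)$ the forward solution, and the cost-functional term suppressed): the adjoint equation is

$$\lambda'(t) = -\left(\frac{\partial f}{\partial u}\bigl(y(t),p,t\bigr)\right)^{T}\lambda,$$

so its tgrad, i.e. its explicit $t$-derivative at fixed $\lambda$, needs the total time-derivative of the forward Jacobian along the forward solution:

$$\frac{\partial}{\partial t}\!\left[-\left(\frac{\partial f}{\partial u}\right)^{T}\lambda\right] = -\left(\frac{\partial^{2} f}{\partial t\,\partial u} + \frac{\partial^{2} f}{\partial u^{2}}\,f\bigl(y(t),p,t\bigr)\right)^{T}\lambda,$$

which involves second derivatives of $f$, not just the user's tgrad $= \partial f/\partial t$.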

ChrisRackauckas avatar Mar 06 '19 00:03 ChrisRackauckas

Thank you Chris, you just saved me a lot of time.

Right now I think the best way for me is to use the autodiff and make it work. Actually I am just trying to run the following code:

using Random
using LinearAlgebra
using DifferentialEquations
using Flux
using DiffEqFlux

struct DETN
    p; q; F; G; Z; L;
end

DETN(p::Int,q::Int) = DETN(p, q, Dense(p+q,p+q,sigmoid),
                                 Dense(p+q,p+q,sigmoid),
                                 Dense(p+q,p,sigmoid),
                                 Dense(p+q,q,sigmoid));
Flux.@treelike DETN

p = 5;
q = 1;
dt = 0.05f0;
sigma = 0.10f0;
tspan = (0.0f0,100.0f0);
tsave = tspan[1]:dt:tspan[2];

function noise(q,t)
    vv = Vector(undef,q);
    for i in 1:length(vv)
        # t is a scalar here, so no broadcasting is needed
        vv[i] = exp(-0.5*((50.0-t)/sigma)^2)/(sigma*sqrt(2pi));
    end
    return Float32.(vv);
end

function neural_ode(model::DETN,x,tspan,args...; kwargs...)
    Flux.Tracker.istracked(x) && error("u0 is not currently differentiable.")
    p = DiffEqFlux.destructure(model)
    #-------------------------------------
    function dudt_(du,u,p,t)
        B = DiffEqFlux.restructure(model,p);
        du .= -B.F(u).*u + B.G(u).*vcat(B.Z(u),noise(B.q,t));
    end
    #-------------------------------------
    func = ODEFunction(dudt_);
    prob = ODEProblem(func,x,tspan,p);
    return diffeq_adjoint(p,prob,args...;kwargs...);
end

Random.seed!(0);
u0 = vcat(rand(p),zeros(q));
B = DETN(p,q);
ps = Flux.params(B);
solver = Rosenbrock23();
sol = neural_ode(B,u0,tspan,solver,saveat=tsave,reltol=1e-6,abstol=1e-6,dt=0.05);

And it gives me the following error,

ERROR: LoadError: TypeError: in setindex!, in typeassert, expected ForwardDiff.Dual{ForwardDiff.Tag{DiffEqDiffTools.TimeGradientWrapper{ODEFunction{true,getfield(Main, Symbol("#dudt_#4")){DETN},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float32,1}},Float32},Float64,1}, got ForwardDiff.Dual{ForwardDiff.Tag{DiffEqDiffTools.TimeGradientWrapper{ODEFunction{true,getfield(Main, Symbol("#dudt_#4")){DETN},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float32,1}},Float32},ForwardDiff.Dual{Nothing,Float64,1},1}

The error occurs at the step when the autodiff calls my dudt_ with a ForwardDiff.Dual object, but I am not an expert in the ForwardDiff package. Would you help me with this?

Thanks in advance! Junteng

000Justin000 avatar Mar 06 '19 00:03 000Justin000

What about using solver = Rosenbrock23(autodiff=false)?

ChrisRackauckas avatar Mar 07 '19 19:03 ChrisRackauckas

It is very slow.

000Justin000 avatar Mar 07 '19 19:03 000Justin000

Adjoints are a long calculation. Is there a speed baseline that we know this should be matching?

I would speed this up by instead using Rodas5(autodiff=false,diff_type=Val{:forward}), or maybe an SDIRK method can do better. But I'm not sure autodiff vs. no autodiff inside the Rosenbrock routine would make much of a difference here. The main cost is the number of f calls and W inversions. Maybe qsteady with an SDIRK method? Or maybe BDF is just a solid method here, since it reduces f calls at the price of more but smaller W inversions? CVODE_BDF() would then be an interesting thing to try. Or checkpointing? Finding the method that optimizes this is a separate question.
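
Concretely, those alternatives look something like this (a sketch on a toy stiff problem standing in for the adjoint system; the Rodas5 options are as written above for the current API, and CVODE_BDF requires Sundials.jl to be installed):

```julia
using OrdinaryDiffEq, Sundials

# toy stiff problem standing in for the adjoint system
f!(du, u, p, t) = (du .= -1000.0 .* u)
prob = ODEProblem(f!, [1.0], (0.0, 1.0))

# Rosenbrock method with forward finite differences instead of autodiff
sol1 = solve(prob, Rodas5(autodiff = false, diff_type = Val{:forward}))

# an SDIRK method: fewer W factorizations per step than Rosenbrock
sol2 = solve(prob, KenCarp4(autodiff = false))

# BDF via Sundials: cheap steps when f calls dominate the cost
sol3 = solve(prob, CVODE_BDF())
```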

ChrisRackauckas avatar Mar 07 '19 20:03 ChrisRackauckas

Thank you for the advice, I will play with all those

000Justin000 avatar Mar 07 '19 20:03 000Justin000

We're also playing with a pretty huge generated example today, so we might want to loop around and check the performance here after we optimize a few SDIRK methods. @YingboMa

ChrisRackauckas avatar Mar 07 '19 20:03 ChrisRackauckas

That's good to know!

000Justin000 avatar Mar 07 '19 20:03 000Justin000

You probably don't want to use Rosenbrock methods for this. I will look into it.

YingboMa avatar Mar 07 '19 21:03 YingboMa

Thanks!

000Justin000 avatar Mar 08 '19 00:03 000Justin000