
suggested addition to UDE tutorial


Here's a chunk of code that one could add to the UDE tutorial.

This code illustrates four things.

  1. Minibatches in the loss function
  2. Dynamic composition of layers to form a model inside the loss function
  3. Training on models that have both fixed and variable parameters
  4. How to use neural ODEs with different starting u0 values in the minibatch

I know from talking to Chris R. that these four points seem staggeringly obvious once you know how to do them! But I can promise you that the current NeuralODE tutorials are misleading about how one gathers the parameters, how one designates which parameters get optimized, how the loss function has to reconstruct the model from both the optimized and the fixed parameters, and the apparent fact that the u0 passed into ODEProblem can be ignored (it is overridden at solve time).

using Flux #for ADAM 
using OrdinaryDiffEq
using DiffEqFlux

model1 = FastChain(FastDense(2,2))
model2 = FastChain(FastDense(2,2))
nontrainable = initial_params(model2)
p = initial_params(model1)

f(x,p,t) = model2(model1(x,p),nontrainable)

u0 = Float32[1.0,1.0]
u1 = Float32[0.0,1.0]
u2 = Float32[1.0,0.0]

minibatch = [u0,u1,u2]

prob = ODEProblem(f,u0,(0.0f0,1.0f0),p)  # u0 here is a placeholder; concrete_solve overrides it below

function loss_batch(p,mb)
   s = 0.0f0

   for u in mb    # note: this loop is better done in parallel by passing a 2D u0 to concrete_solve (see the batched sketch below)
      preds = concrete_solve(prob,Tsit5(),u,p, saveat=0.01)
      s += sum(abs, preds .- u)
   end
   s
end

loss(p) = loss_batch(p,minibatch)

function cb1(args...)
   println("args:",args[1:2])
   false
end

res0 = DiffEqFlux.sciml_train(loss,p,ADAM(0.005),maxiters=300,cb=cb1)
res1 = DiffEqFlux.sciml_train(loss,res0.minimizer,ADAM(0.0005),maxiters=300,cb=cb1)
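For reference, here is roughly what the 2D-u0 version of the loss mentioned in the loop comment could look like. This is an untested sketch: u0_batch and loss_batch2 are just illustrative names, and it assumes the FastDense layers and concrete_solve accept a matrix of initial conditions, one column per sample.

u0_batch = hcat(minibatch...)   # 2x3 matrix, one column per initial condition

function loss_batch2(p)
   preds = concrete_solve(prob,Tsit5(),u0_batch,p, saveat=0.01)   # one solve covers all columns
   sum(abs, preds .- u0_batch)  # broadcast the 2x3 starting states against every saved step
end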

Comments: The above trivial example effectively finds a second layer that is the inverse of the first layer, so that the input is the output. To do this it requires a complete basis in the minibatch.
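One way to see this is to re-solve with the trained parameters and check that each trajectory ends where it started (a quick sketch reusing the objects defined above):

for u in minibatch
   sol = concrete_solve(prob,Tsit5(),u,res1.minimizer, saveat=0.01)
   println(u, " => ", sol[end])   # the final state should be close to the input u
end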

By the way, the above example fails with BFGS optimization. If you remove the minibatch and just use a single u0 input, then the BFGS stage works; if you don't, it emits instability/abort warnings and fails.

# the following will give an error about instability.  I'm not sure why.
using Optim # for BFGS
res = DiffEqFlux.sciml_train(loss,res1.minimizer,BFGS() ,maxiters=30,cb=cb1)

julia> res
 * Status: failure (line search failed)

 * Candidate solution
    Minimizer: [-7.96e-03, -6.98e-03, -7.91e-03,  ...]
    Minimum:   4.455010e-01

 * Found with
    Algorithm:     BFGS
    Initial Point: [-7.88e-03, -7.07e-03, -8.33e-03,  ...]

 * Convergence measures
    |x - x'|               = 3.46e-04 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.17e-02 ≰ 0.0e+00
    |f(x) - f(x')|         = 9.45e-03 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 2.12e-02 ≰ 0.0e+00
    |g(x)|                 = 1.35e+02 ≰ 1.0e-08

 * Work counters
    Seconds run:   2  (vs limit Inf)
    Iterations:    2
    f(x) calls:    33
    ∇f(x) calls:   33
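
For reference, the single-u0 variant I mean is something like this (loss_single is just a placeholder name); with one initial condition instead of the minibatch, the BFGS stage runs without the instability warnings:

loss_single(p) = loss_batch(p, [u0])   # same loss, but only the one initial condition
res_single = DiffEqFlux.sciml_train(loss_single, res1.minimizer, BFGS(), maxiters=30, cb=cb1)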

— cems2, Feb 19 '20

@collinwarner can you make sure these ideas get into the docs, maybe this as an advanced tutorial?

— ChrisRackauckas, May 12 '20

We have tutorials covering these pieces nowadays.

— ChrisRackauckas, Nov 22 '23