Optimization.jl

Optimize for N epochs. Invoke custom callback at the end of an epoch.

Open lungd opened this issue 4 years ago • 4 comments

I want to propose some changes but I am not entirely sure where they would fit best.

Let's assume there is an iterator for a training set. The commonly suggested way to train/optimize for N epochs is to ncycle the training set. I can define a callback that logs some information (e.g. the loss); it gets passed to the optimizer and is invoked at the end of every iteration. Because the training set can contain many samples, I don't want to print the loss of every sample but rather the mean at the end of an epoch (once all samples have been taken). Printing the individual samples' losses should still be possible, though.
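For reference, a minimal sketch of that pattern (assuming optprob, opt, and a train_loader iterator are already defined; the callback here fires once per sample):

using IterTools: ncycle

epochs = 100
# per-iteration callback: prints the loss of every single sample
cb = (p, l) -> (println("loss: ", l); false)
res = GalacticOptim.solve(optprob, opt, ncycle(train_loader, epochs), cb = cb)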

I can imagine a __solve() function accepting all the usual optimizer arguments plus some additional ones (epochs, epoch_cb).

Here are some parts of my code. It already works but could be more customizable.

# maybe something like
function __solve(prob::OptimizationProblem, opt, data;
                 cb = (args...) -> (false), ecb = (args...) -> (false), epochs = nothing,
                 kwargs...)

  progress = ProgressMeter.Progress(length(data); showspeed = true)
  losses = Float32[]
  # per-iteration callback: record the loss, update the progress bar, then call the user callback
  function _cb(p, l, args...; kwargs...)
    push!(losses, l)
    x = "$(losses[end])"
    ProgressMeter.next!(progress; showvalues = [(:loss, x)])
    cb(p, l, args...; kwargs...)
  end
  ...
  optprob = GalacticOptim.OptimizationProblem(...)
  ...
  # first epoch
  progress = ProgressMeter.Progress(length(data); showspeed = true)
  res = GalacticOptim.solve(optprob, opt, data, cb = _cb)
  ecb(losses, 1, res)   # epoch callback, e.g. print the mean loss
  losses = Float32[]
  # remaining epochs: restart from the result of the previous epoch
  for epoch in 2:epochs
    prob = remake(optprob, u0 = res.u)
    progress = ProgressMeter.Progress(length(data); showspeed = true)
    res = GalacticOptim.solve(prob, opt, data, cb = _cb)
    ecb(losses, epoch, res)
    losses = Float32[]
  end
  return res
end

The output I get:

Epoch 019/100, mean train_loss:3.6367512
Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:44 ( 1.17  s/it)
  loss:  3.7939742
Epoch 020/100, mean train_loss:3.578114
Progress:  61%|███████████████████████████████████████████████████████████████                                         |  ETA: 0:00:18 ( 1.20  s/it)
  loss:  2.0184839

Is something like this possible already? If not and you think it's worth creating a PR I am happy to do so.

lungd avatar Aug 11 '21 18:08 lungd

Interesting, and default epochs to 1? I could see that being useful. Will need good docs though.

ChrisRackauckas avatar Aug 13 '21 20:08 ChrisRackauckas

What's wrong with putting a counter in the callbacks?

ChrisRackauckas avatar Aug 13 '21 20:08 ChrisRackauckas

Sure, you could make use of a counter to define what to do after X iterations.

function cb(args...)
  # sample_callback()
  if counter % X == 0
    # epoch_callback()
  end
  counter += 1
end

That's what you have in mind, right? While this is indeed not that hard to implement, why not give the user the option to just define the bodies of the different callbacks and let the package take care of everything else?
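Concretely, a rough sketch of that counter approach (make_cb, sample_cb, and epoch_cb are just illustrative names; X is the number of iterations per epoch):

function make_cb(X; sample_cb = (args...) -> nothing, epoch_cb = (args...) -> nothing)
  counter = 0
  losses = Float32[]
  function cb(p, l, args...; kwargs...)
    counter += 1
    push!(losses, l)
    sample_cb(p, l)
    if counter % X == 0
      epoch_cb(sum(losses) / length(losses))   # e.g. report the mean loss of the epoch
      empty!(losses)
    end
    return false
  end
  return cb
end

which could then be passed as cb = make_cb(length(train_loader)) to solve.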

Actually, now I think a third callback could make sense (in case there is a data loader). Some commonly used callbacks (in terms of when they are invoked):

  1. sample_callback (SC): invoke after every iteration
  2. epoch_callback (EC): invoke after all training samples have been taken.
  3. iterative_callback (IC): invoke after X iterations
  4. start callback?
  5. end callback? (I will stop here)

The function could handle the following cases:

  1. No data: invoke SC every 1? Invoke EC every 1? Invoke IC every X? epochs = maxIters? In this case I think only SC or EC makes sense, and either epochs or maxIters.

  2. With a data loader (dl): invoke SC every 1, invoke EC every length(dl), invoke IC every X. If length(dl) = 10, the user could set epochs to 10 and maxIters to 90 (I don't know why, but that would be the max number of iterations or training samples).

  3. I was also thinking about an array of optimizers and epochs, e.g. ADAM for 300 epochs, then BFGS for 200 epochs, while still using maxIters as an upper bound (see the sketch below).
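For case 3, a rough sketch of what I mean, building on the remake pattern from my snippet above (solve_schedule is just an illustrative name, and maxIters handling is omitted):

function solve_schedule(optprob, opts, nepochs, data, _cb)
  prob = optprob
  res = nothing
  for (opt, n) in zip(opts, nepochs)
    for epoch in 1:n
      res = GalacticOptim.solve(prob, opt, data, cb = _cb)
      prob = remake(prob, u0 = res.u)   # continue the next epoch from the previous result
    end
  end
  return res
end

# e.g. ADAM for 300 epochs, then BFGS for 200 epochs
res = solve_schedule(optprob, [ADAM(0.01), BFGS()], [300, 200], data, _cb)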

I agree with epochs = 1 as the default.

What do you think?

lungd avatar Aug 13 '21 23:08 lungd

That all just seems overly complicated. Why not just have constructors for standard callback types?

ChrisRackauckas avatar Aug 14 '21 12:08 ChrisRackauckas