mlr3torch

Continue training a learner

Open: sebffischer opened this issue 2 years ago • 3 comments

Once we have trained a LearnerTorch, we need a way to continue training it via $continue(). As an example, consider the graph learner below:

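library(mlr3pipelines)
library(mlr3torch)

# "mtcars" is a regression task, so use a regression loss and model PipeOp
graph = po("pca") %>>%
  po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam") %>>%
  po("torch_model_regr", batch_size = 16, epochs = 100, device = "cuda")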

learner = as_learner(graph)
learner$train(tsk("mtcars"))  

The question now is how to conveniently continue training this learner.

Ideally, we could simply do:

learner$param_set$set_values(
  epochs = 50, # an additional 50 epochs
  batch_size = 32
)
learner$continue(tsk("mtcars"))

For this to work, we either need to overwrite mlr3's as_learner() method, or the LearnerSupervised class must implement a $continue() method.

Ideally, the mlr3::Learner class would provide a $continue() method in combination with a "continuable" property, similar (but not equivalent) to the hotstarting idea. It differs in that more than one parameter value can be changed (e.g. we might change both batch_size and epochs, as shown above). The intended UI shown above is also more user-friendly than the hotstarting UI.

Because the above learner is a GraphLearner, the $continue() call must eventually be executed by the Graph class. One way the Graph could implement this is to throw an error if none of its PipeOps has a $continue() method. If at least one of the contained PipeOps has a $continue() method, we could call graph_reduce() and, for each PipeOp, call $continue() if it exists and $train() otherwise. The $continue() method of all PipeOps following the naming scheme PipeOpTorch<XXX> should simply ignore the call. The exception is PipeOpTorchModel, which must implement $continue() and delegate the work to LearnerTorchModel.
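A minimal base-R sketch of this dispatch rule, assuming a linear chain of PipeOps rather than a general graph (so no graph_reduce()). make_op() and continue_graph() are hypothetical stand-ins, not mlr3pipelines classes: the call fails if no op is continuable; otherwise each op is continued if it can be and re-trained if not.

```r
# Hypothetical stand-in for a PipeOp (not the mlr3pipelines class): an
# environment with an id, a train() function and, optionally, a continue()
# function. Each method records its name in op$calls and passes input through.
make_op = function(id, continuable = FALSE) {
  op = new.env()
  op$id = id
  op$calls = character(0)
  op$train = function(input) {
    op$calls = c(op$calls, "train")
    input
  }
  if (continuable) {
    op$continue = function(input) {
      op$calls = c(op$calls, "continue")
      input
    }
  }
  op
}

# Continue a linear chain of ops: error if no op is continuable; otherwise
# call continue() where it exists and train() everywhere else, feeding each
# op's output into the next op.
continue_graph = function(ops, input) {
  continuable = vapply(ops, function(op) is.function(op$continue), logical(1))
  if (!any(continuable)) {
    stop("None of the PipeOps in this Graph has a $continue() method.")
  }
  for (op in ops) {
    input = if (is.function(op$continue)) op$continue(input) else op$train(input)
  }
  input
}

ops = list(
  make_op("pca"),                                   # no continue(): re-trained
  make_op("torch_ingress_num", continuable = TRUE), # continue() is a no-op
  make_op("torch_model", continuable = TRUE)        # would delegate to the learner
)
out = continue_graph(ops, input = 1:3)
sapply(ops, function(op) op$calls)
```

Here the ingress op's continue() just passes its input through, mirroring the suggestion that PipeOpTorch<XXX> operators ignore the call, while the model op is where the real work would be delegated.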

There are probably many edge cases to consider, but this general idea at least seems convenient from a user perspective, and I think it works for the simple example above.

Note that $continue() should either overwrite the learner's model or modify it in place, so that $predict() works as usual afterwards.
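To make the intended semantics concrete (epochs counts additional epochs, other parameters such as batch_size may change, and the stored model is overwritten so that predict() keeps working), here is a self-contained base-R toy. It is not mlr3's Learner class: make_toy_learner() is hypothetical and fits y ~ w * x + b by mini-batch gradient descent. Tracking epochs_total in the state is one possible answer to the open question about what the state should contain.

```r
# Hypothetical toy learner (not mlr3's Learner class): fits y ~ w * x + b by
# mini-batch gradient descent on the squared error.
make_toy_learner = function(epochs = 10, batch_size = 16, lr = 0.01) {
  self = new.env()
  self$param_set = list(epochs = epochs, batch_size = batch_size, lr = lr)
  self$state = NULL

  # Run n_epochs passes over shuffled mini-batches, starting from (w, b).
  run_epochs = function(x, y, w, b, n_epochs) {
    for (epoch in seq_len(n_epochs)) {
      idx = sample(length(x))
      batches = split(idx, ceiling(seq_along(idx) / self$param_set$batch_size))
      for (batch in batches) {
        err = (w * x[batch] + b) - y[batch]
        w = w - self$param_set$lr * 2 * mean(err * x[batch])
        b = b - self$param_set$lr * 2 * mean(err)
      }
    }
    list(w = w, b = b)
  }

  # train(): always starts from scratch (w = b = 0).
  self$train = function(x, y) {
    fit = run_epochs(x, y, w = 0, b = 0, n_epochs = self$param_set$epochs)
    self$state = list(model = fit, epochs_total = self$param_set$epochs)
    invisible(self)
  }

  # continue(): runs `epochs` additional epochs from the stored weights,
  # using the current batch_size, then overwrites the stored model.
  self$continue = function(x, y) {
    if (is.null(self$state)) stop("The learner must be trained before it can be continued.")
    fit = run_epochs(x, y, self$state$model$w, self$state$model$b, self$param_set$epochs)
    self$state = list(model = fit, epochs_total = self$state$epochs_total + self$param_set$epochs)
    invisible(self)
  }

  # predict(): uses whatever model is currently stored.
  self$predict = function(x) self$state$model$w * x + self$state$model$b
  self
}

set.seed(1)
x = runif(200)
y = 3 * x + 1 + rnorm(200, sd = 0.1)

learner = make_toy_learner(epochs = 100, batch_size = 16, lr = 0.1)
learner$train(x, y)

learner$param_set$epochs = 50     # an additional 50 epochs
learner$param_set$batch_size = 32
learner$continue(x, y)

learner$state$epochs_total        # 150
learner$predict(0.5)              # close to 3 * 0.5 + 1 = 2.5
```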

Open questions:

  • Does this work out?
  • How to define the learner's state (regarding the other meta-information like logs etc.)?
  • Must the task used during $continue() be identical to the task used during $train()? (Probably not; in that case, $continue() is also related to fine-tuning a torch learner.)

sebffischer, Oct 17 '23 15:10

This is basically equivalent to hotstarting with a slightly different UI, so I think we should build on the existing hotstarting code.

sebffischer, Jun 14 '24 12:06

The callbacks also need to be continuable; e.g., some of the learning rate schedulers need to set last_epoch so that they can be resumed properly. This needs to be handled!
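A small base-R illustration of why the epoch counter must survive a restart. A step-decay rule lr = lr0 * gamma^floor(epoch / step_size), the kind of rule used by step schedulers (e.g. torch's lr_step()), depends only on the epoch counter. Resuming with the counter reset therefore jumps the learning rate back to lr0. step_lr() and its parameter values are illustrative assumptions, not the scheduler implementation.

```r
# Step decay as a pure function of the epoch counter:
# lr = lr0 * gamma^floor(epoch / step_size). Parameter values are illustrative.
step_lr = function(epoch, lr0 = 0.1, gamma = 0.5, step_size = 30) {
  lr0 * gamma^floor(epoch / step_size)
}

first_run = step_lr(0:99)        # initial training: epochs 0..99
resumed_reset = step_lr(0:49)    # 50 more epochs, counter restarted at 0
resumed_kept = step_lr(100:149)  # 50 more epochs, counter resumed after epoch 99

tail(first_run, 1)    # 0.0125
resumed_reset[1]      # 0.1: the learning rate jumps back up
resumed_kept[1]       # 0.0125: the schedule continues smoothly
```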

sebffischer, Jan 14 '25 14:01

I think for now we can just offer a t_clbk("resume", weights_path, optim_path, epochs) callback, to which one passes the paths to the network weights and the optimizer state as well as the number of epochs (possibly inferring some of this). There should also be a way to automatically continue from a checkpoint as obtained by t_clbk("checkpoint").
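The "possibly inferring some of this" part could look like the following hypothetical base-R helper, latest_checkpoint(). It scans a checkpoint directory and returns the weights path, optimizer path and epoch of the most recent complete checkpoint, which could then be handed to the proposed resume callback. The file names network<epoch>.pt and optimizer<epoch>.pt are an assumption about what t_clbk("checkpoint") writes; adjust the pattern to the actual layout.

```r
# Hypothetical helper: find the most recent epoch for which BOTH a network
# file and an optimizer file exist. The "network<epoch>.pt" /
# "optimizer<epoch>.pt" naming scheme is an assumption for illustration.
latest_checkpoint = function(dir) {
  files = list.files(dir)
  # Extract the epoch numbers of all files matching "<prefix><epoch>.pt".
  epochs_of = function(prefix) {
    hits = regmatches(files, regexec(paste0("^", prefix, "([0-9]+)\\.pt$"), files))
    as.integer(vapply(Filter(length, hits), `[`, character(1), 2))
  }
  common = intersect(epochs_of("network"), epochs_of("optimizer"))
  if (length(common) == 0) stop("No complete checkpoint found in ", dir)
  epoch = max(common)
  list(
    weights_path = file.path(dir, sprintf("network%d.pt", epoch)),
    optim_path = file.path(dir, sprintf("optimizer%d.pt", epoch)),
    epoch = epoch
  )
}

dir = tempfile("checkpoints")
dir.create(dir)
# Epoch 30 has weights but no optimizer state, so it is not a complete checkpoint.
invisible(file.create(file.path(dir, c(
  "network10.pt", "optimizer10.pt",
  "network20.pt", "optimizer20.pt",
  "network30.pt"
))))
ckpt = latest_checkpoint(dir)
ckpt$epoch   # 20
```

The number of remaining epochs could then be derived as the desired total minus ckpt$epoch.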

sebffischer, Jan 23 '25 09:01