Continue training a learner
Once we have trained a LearnerTorch, it is important to have the option to $continue() training it.
As an example, consider the graph learner below:
graph = po("pca") %>>%
  po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam") %>>%
  po("torch_model", batch_size = 16, epochs = 100, device = "cuda")
learner = as_learner(graph)
learner$train(tsk("mtcars"))
The question is now how to conveniently continue training the learner.
Ideally, we could simply do:
learner$param_set$set_values(
  epochs = 50, # an additional 50 epochs
  batch_size = 32
)
learner$continue(tsk("mtcars"))
For this to work, we either need to overwrite mlr3's as_learner() method, or the LearnerSupervised class must implement a $continue() method.
Ideally, the mlr3::Learner class provides a $continue() method in combination with a "continuable" property, similar (but not equivalent) to the hotstarting idea. It differs in that more than one parameter value can be changed (e.g. we might change both the batch size and the number of epochs, as shown above). Also, the intended UI shown above is more user-friendly than the hotstarting UI.
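For comparison, hotstarting with the current mlr3 API looks roughly like this (using the debug learner; exact details of the API may differ):

library(mlr3)

task = tsk("sonar")
learner_1 = lrn("classif.debug", iter = 1)
learner_1$train(task)

# to continue training, a second learner with the new parameter value is
# created and pointed at a HotstartStack that contains the trained learner
learner_2 = lrn("classif.debug", iter = 2)
learner_2$hotstart_stack = HotstartStack$new(list(learner_1))
learner_2$train(task)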
Because the learner in our example is a GraphLearner, the $continue() call must eventually be handled by the Graph class.
One way for the Graph to implement $continue() would be to throw an error if none of the contained pipeops has a $continue() method. If at least one of them does, we could call graph_reduce() and, for each PipeOp, call $continue() if it is available and fall back to $train() otherwise.
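A minimal sketch of this dispatch, assuming a linear graph and ignoring the channel handling that graph_reduce() performs (continue_graph() and the per-pipeop $continue() method are hypothetical):

continue_graph = function(graph, task) {
  pipeops = graph$pipeops[graph$ids(sorted = TRUE)]
  has_continue = vapply(pipeops, function(po) is.function(po$continue), logical(1))
  if (!any(has_continue)) {
    stop("None of the pipeops in this graph implements $continue().")
  }
  input = list(task)
  for (po in pipeops) {
    # dispatch to $continue() where available, otherwise fall back to $train()
    input = if (is.function(po$continue)) po$continue(input) else po$train(input)
  }
  invisible(graph)
}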
The $continue() method of all pipeops following the PipeOpTorch<XXX> naming scheme should simply ignore the $continue() call.
That is, the PipeOpTorchModel class needs to implement the $continue() method (and will delegate this responsibility to LearnerTorchModel).
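A toy sketch of this delegation (not the real mlr3torch classes; it only illustrates that the pipeop forwards the call to the learner it wraps):

library(R6)

ToyPipeOpTorchModel = R6Class("ToyPipeOpTorchModel",
  public = list(
    learner = NULL,
    initialize = function(learner) {
      self$learner = learner
    },
    train = function(inputs) {
      self$learner$train(inputs[[1]])
      list(NULL)
    },
    continue = function(inputs) {
      # delegate: the wrapped torch learner resumes its optimizer loop in place
      self$learner$continue(inputs[[1]])
      list(NULL)
    }
  )
)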
There are probably many cases to consider, but this is the general idea; it at least seems convenient from a user perspective and I think it works in the simple example above.
Note that $continue() should either overwrite the learner's model or modify it in place, so that $predict() keeps working as usual.
Open questions:
- Does this approach work out?
- How should the learner's state be defined (regarding other meta-information such as logs)?
- Must the task used during $continue() be identical to the task used during $train()? (Probably not; in that case $continue() would also be relevant for fine-tuning a torch learner.)
This is basically equivalent to hotstarting with a slight UI change, so I think we should rely on the existing hotstarting code.
The callbacks also need to be continuable; e.g. some learning rate schedulers need last_epoch to be set so they can be resumed properly. This needs to be handled!
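For example, with the torch package a step scheduler that is recreated after a resume has to be advanced to the epoch at which training stopped, otherwise the learning rate schedule restarts from its initial value (toy network and hyperparameters for illustration):

library(torch)

net = nn_linear(10, 1)
opt = optim_adam(net$parameters, lr = 0.1)
scheduler = lr_step(opt, step_size = 30, gamma = 0.1)

# naive resume after 100 epochs: fast-forward the scheduler so that the
# learning rate continues where the first run left off (equivalently, the
# scheduler could be constructed with last_epoch = 100 once the optimizer
# state has been restored from a checkpoint)
for (i in seq_len(100)) scheduler$step()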
I think for now we can just offer a t_clbk("resume", weights_path, optim_path, epochs), where one can pass paths to the weights and the optimizer state as well as the epochs (possibly inferring some of this). There should also be a way to automatically continue from a checkpoint as produced by t_clbk("checkpoint").
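A possible usage of the proposed callback could look like this (t_clbk("resume") does not exist yet and the argument names are purely illustrative; the checkpoint files would come from the existing t_clbk("checkpoint") callback or from a previous run):

learner = lrn("regr.mlp",
  epochs = 50, batch_size = 32,
  callbacks = t_clbk("resume",
    weights_path = "checkpoints/network.pt",  # network weights from the previous run
    optim_path = "checkpoints/optimizer.pt",  # optimizer state from the previous run
    epochs = 100                              # epochs already trained
  )
)
learner$train(tsk("mtcars"))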