mlr3torch
restart learning from most recent checkpoint?
Can you please add some mention of how to use the .pt files shown in the example on https://mlr3torch.mlr-org.com/reference/mlr_callback_set.checkpoint.html ? @sebffischer is it possible to start learning in a new R session from one of these files? (I believe it should be possible, because that is the goal of checkpointing, so I expected it would be documented.)
Thanks!!
related to https://github.com/mlr-org/mlr3torch/issues/142#issuecomment-2609319544
Yeah, this should definitely be possible. In principle we have a mechanism for this in mlr3, called "hotstarting", so we were also discussing implementing this in a "nice" way.
The idea would be to make something like `learner$train(task); learner$configure(epochs = 5); learner$continue(task)` work as well. But at least for now we can solve it via the callback, which is easier.
I am a bit busy until next Friday, so if you need it before then you would have to go at it yourself. It should not be too difficult, though. https://mlr3torch.mlr-org.com/articles/callbacks.html contains a short tutorial on adding a custom callback.
In principle, you need to:
- call `self$ctx$network$load_state_dict(network_params)`, where `network_params` is the `$state_dict()` of the network
- call `self$ctx$optimizer$load_state_dict(optim_state)`, where `optim_state` is the `$state_dict()` of the optimizer
You might also want to handle loading the state dicts of the callbacks, but I guess this is not absolutely necessary (e.g. if you don't handle this, the history from the previous training run would be "forgotten"). There is already some infrastructure in place for the callbacks as well (`$state_dict()` and `$load_state_dict()`), so this should also be possible.
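For reference, here is a rough sketch of what such a callback could look like, following the custom-callback tutorial linked above. The file paths, callback id, learner, and task are placeholders: you would point `network_path` and `optimizer_path` at whatever `.pt` files the checkpoint callback actually wrote in the earlier run, and adjust the learner configuration to match it.

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Hypothetical paths: point these at the .pt files written by the
# checkpoint callback during the previous (interrupted) run.
network_path   <- "checkpoints/network.pt"
optimizer_path <- "checkpoints/optimizer.pt"

# Custom callback that restores the network and optimizer state dicts
# before the training loop of the new run starts.
resume_cb <- torch_callback("resume",
  on_begin = function() {
    network_params <- torch_load(network_path)
    optim_state    <- torch_load(optimizer_path)
    self$ctx$network$load_state_dict(network_params)
    self$ctx$optimizer$load_state_dict(optim_state)
  }
)

# Example learner/task; in a fresh R session this would continue from the
# checkpointed weights rather than from a random initialization.
learner <- lrn("classif.mlp",
  epochs = 100, batch_size = 32,
  callbacks = resume_cb
)
learner$train(tsk("sonar"))
```

Restoring the callbacks' own state (e.g. the history) would work the same way via their `$load_state_dict()` methods, if you need it.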
Thanks for the info. I probably won't have time to work on this for a few weeks. For context, I run some experiments on the cluster with a lot of epochs (10000), and some jobs are interrupted before they finish due to time limits. It would be nice to be able to restart these jobs at, say, 5000 epochs instead of 0.