restore_training_state before on_fit_start?
Description & Motivation
I need to move some optimizer states to the device of the corresponding gradients of the embeddings. I extended the optimizer to do this after super().load_state_dict(), but _optimizer_to_device(optimizer, self.root_device) then moves them back from CPU to the accelerator.
There is also no way to do this in on_fit_start, which https://github.com/Lightning-AI/pytorch-lightning/issues/8035 proposed for parameters: that approach does not work for optimizer state, because optimizer state is loaded after on_fit_start while parameters are loaded before it.
See also https://github.com/Lightning-AI/pytorch-lightning/issues/3698.
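For concreteness, here is a minimal sketch of the kind of override described above (assuming Adam; the class name and the grad-device heuristic are illustrative, not code from my project):

```python
import torch

class EmbeddingAwareAdam(torch.optim.Adam):
    """Hypothetical subclass: after loading a checkpoint, keep every state
    tensor on the device of the corresponding parameter's gradient."""

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        for param, state in self.state.items():
            # Fall back to the parameter's own device when no grad exists yet.
            target = param.grad.device if param.grad is not None else param.device
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(target)
```

This is exactly the movement that _optimizer_to_device(optimizer, self.root_device) later undoes.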
Pitch
Move

```python
# hook
if self.state.fn == TrainerFn.FITTING:
    call._call_callback_hooks(self, "on_fit_start")
    call._call_lightning_module_hook(self, "on_fit_start")
```

after

```python
# restore optimizers, etc.
log.debug(f"{self.__class__.__name__}: restoring training state")
self._checkpoint_connector.restore_training_state()
```

so that on_fit_start runs with the optimizer state already restored.
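With that ordering, a hook along the following lines would work. This is a sketch under the assumption that trainer.optimizers exposes the restored optimizers; the grad-device heuristic mirrors the use case above:

```python
import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def on_fit_start(self):
        # Hypothetical: this only works if restore_training_state() has
        # already run, i.e. with the reordering proposed above.
        for optimizer in self.trainer.optimizers:
            for param, state in optimizer.state.items():
                target = param.grad.device if param.grad is not None else param.device
                for key, value in state.items():
                    if torch.is_tensor(value):
                        state[key] = value.to(target)
```

This would make the parameter-relocation pattern from issue 8035 applicable to optimizer state as well.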
Alternatives
Can't think of an alternative solution. If someone knows one, let me know.
Additional context
No response
cc @borda