
restore_training_state before on_fit_start?

Open lampuiho opened this issue 4 months ago • 0 comments

Description & Motivation

I need to move some optimizer states to the device of the corresponding grad of the embeddings. I extended the optimizer to do this after `super().load_state_dict`, but `_optimizer_to_device(optimizer, self.root_device)` then moves them back from CPU to the accelerator. There is also no way to do it in `on_fit_start`, which https://github.com/Lightning-AI/pytorch-lightning/issues/8035 proposed for parameters, because optimizer state loading happens after `on_fit_start`, while parameter loading happens before it.
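To illustrate the kind of extension I mean, here is a minimal sketch (not my actual code; `DeviceAwareAdam` is an invented name) of an optimizer that, after loading, moves each state tensor to the device of the corresponding parameter's grad:

```python
import torch

class DeviceAwareAdam(torch.optim.Adam):
    """Sketch only: relocate loaded optimizer state to each param's grad device."""

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p not in self.state:
                    continue
                target = p.grad.device
                for k, v in self.state[p].items():
                    # Move every tensor-valued state entry (exp_avg, etc.)
                    if torch.is_tensor(v):
                        self.state[p][k] = v.to(target)
```

With the current Trainer flow, `_optimizer_to_device` runs after this and undoes the relocation, which is exactly the problem.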

see also https://github.com/Lightning-AI/pytorch-lightning/issues/3698

Pitch

Move

        # hook
        if self.state.fn == TrainerFn.FITTING:
            call._call_callback_hooks(self, "on_fit_start")
            call._call_lightning_module_hook(self, "on_fit_start")

After

        # restore optimizers, etc.
        log.debug(f"{self.__class__.__name__}: restoring training state")
        self._checkpoint_connector.restore_training_state()
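The resulting sequence inside `Trainer._run` would look like this (an excerpt sketch, not runnable on its own, and assuming no other callbacks depend on the current ordering):

```python
        # restore optimizers, etc.
        log.debug(f"{self.__class__.__name__}: restoring training state")
        self._checkpoint_connector.restore_training_state()

        # hook -- now fires after optimizer state has been restored,
        # so on_fit_start can adjust optimizer state device placement
        if self.state.fn == TrainerFn.FITTING:
            call._call_callback_hooks(self, "on_fit_start")
            call._call_lightning_module_hook(self, "on_fit_start")
```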

Alternatives

Can't think of an alternative solution. If someone knows one, let me know.

Additional context

No response

cc @borda

lampuiho · Oct 12 '24 03:10