
`PyTorch Lightning`: Support saving model to `model_file`

Open daavoo opened this issue 4 years ago • 10 comments

To be consistent with other integrations.

daavoo, Oct 04 '21 14:10

Copying here from the PR conversation: LightningModule implements the log method, which the user calls to store everything they want, and LightningModule then calls the logger. Checkpoints are saved at the end of a batch/epoch via hooks like on_train_batch_end, etc. The logger doesn't have access to these hooks. To save models, we would have to write a callback and encourage users to use it with our logger. But users can also use the native callbacks, control all saving themselves, and track the results with DVC.
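A minimal sketch of the callback option, assuming a Lightning version where Callback.on_train_epoch_end(trainer, pl_module) exists and Trainer.save_checkpoint(filepath) writes a full checkpoint. SaveModelFile and its model_file argument are hypothetical names, not part of dvclive; the fallback to object only lets the sketch run where Lightning is not installed.

```python
import os

try:
    from pytorch_lightning import Callback
except ImportError:  # lets the sketch run without Lightning installed
    Callback = object


class SaveModelFile(Callback):
    """Hypothetical callback: overwrite one DVC-tracked file after every epoch."""

    def __init__(self, model_file="model.ckpt"):
        super().__init__()
        self.model_file = model_file

    def on_train_epoch_end(self, trainer, pl_module, *args):
        # *args absorbs the extra argument some older Lightning versions pass
        directory = os.path.dirname(self.model_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Same path every time, so DVC tracks a single file
        trainer.save_checkpoint(self.model_file)
```

Usage would be something like Trainer(logger=DvcLiveLogger(), callbacks=[SaveModelFile("ckpts/model.ckpt")], enable_checkpointing=False), with ckpts/model.ckpt declared as a checkpoint: true output in dvc.yaml.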

We could implement the after_save_checkpoint method and save model_file there. But it will duplicate files if the user doesn't disable native checkpointing, so we should encourage users to do so.
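The after_save_checkpoint route could look roughly like this mixin. It assumes Lightning calls logger.after_save_checkpoint(checkpoint_callback) right after a ModelCheckpoint callback writes a file, and that best_model_path / last_model_path already point at that file when the hook fires (worth verifying for your Lightning version). SaveModelFileMixin is a hypothetical name, and the copy it makes is exactly the duplication described above.

```python
import os
import shutil


class SaveModelFileMixin:
    """Hypothetical logger mixin: mirror each saved checkpoint into model_file."""

    # Path declared as a `checkpoint: true` output in dvc.yaml (assumption)
    model_file = "model.ckpt"

    def after_save_checkpoint(self, checkpoint_callback):
        # Lightning passes (a proxy to) the ModelCheckpoint that just saved
        src = checkpoint_callback.best_model_path or checkpoint_callback.last_model_path
        if not src or not os.path.exists(src):
            return
        if os.path.abspath(src) == os.path.abspath(self.model_file):
            return  # ModelCheckpoint already writes straight to the tracked file
        directory = os.path.dirname(self.model_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Overwrite in place so DVC always sees a single, current file
        shutil.copyfile(src, self.model_file)
```

For example, class DvcLiveLoggerWithModelFile(SaveModelFileMixin, DvcLiveLogger) with model_file = "ckpts/model.ckpt". Listing the mixin first makes its after_save_checkpoint take precedence in the MRO.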

sirily, Oct 04 '21 16:10

Hi @daavoo @sirily - I've just been bashing my head against this one and would be happy to implement a solution.

Maybe I've got the idea totally backwards, but I'd like to use DVC to version the checkpoint file and recover state from earlier experiments. So far I have:

dvc.yaml

stages:
  train:
    cmd: python train.py
    deps:
    - data/
    - train.py
    outs:
    - ckpts/model.ckpt:
        checkpoint: true
    live:
      logs:
        summary: true
        html: false

train.py

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from dvclive.lightning import DvcLiveLogger

checkpoint_callback = ModelCheckpoint(filename="model",
                                      dirpath='ckpts',
                                      save_top_k=1)

dvclive_logger = DvcLiveLogger(
    path='logs',
    resume=True,
    summary=True
)

m = MyModel()  # MyModel: your own LightningModule (not shown)

trainer = Trainer(logger=dvclive_logger, callbacks=[checkpoint_callback])

# Resume from the checkpoint if it exists; start fresh on the first run
model_file = 'ckpts/model.ckpt'
trainer.fit(m, ckpt_path=model_file if os.path.exists(model_file) else None)

I encounter the problem that PL doesn't like overwriting files in the ckpt directory, and instead saves new checkpoint files as ckpts/model-v1.ckpt *-v2.ckpt etc., meaning that the new model state is not reflected in dvc.lock

We could implement the after_save_checkpoint method and save model_file there. But it will duplicate files if the user doesn't disable native checkpointing, so we should encourage users to do so.

@sirily - is your suggestion to use after_save_checkpoint to persist PL's checkpoints into a file that DVC is tracking? Would it fire as intended with PL's enable_checkpointing=False to suppress the native checkpointing?

a-orn, Jan 07 '22 01:01

Would it fire as intended with PL's enable_checkpointing=False to suppress the native checkpointing?

@a-orn AFAIK, enable_checkpointing just creates a ModelCheckpoint callback with default parameters. If you use your own ModelCheckpoint callback or some custom callback, you don't have to set this flag. But I can't check it right now; I need a couple of days to get back home.

I encounter the problem that PL doesn't like overwriting files in the ckpt directory, and instead saves new checkpoint files as ckpts/model-v1.ckpt *-v2.ckpt etc

If you set a metric to monitor, PL will save only one checkpoint. Something like this:

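import pytorch_lightning as pl

ckpt_path = "ckpts"  # directory where checkpoints are written
# With a monitored metric and save_top_k=1, only the best checkpoint is kept
cp = pl.callbacks.ModelCheckpoint(
    monitor="train_loss_epoch",
    save_top_k=1,
    dirpath=ckpt_path,
    filename='checkpoint_{epoch}')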

sirily, Jan 07 '22 03:01

Thanks, that's good to know. I think it still results in a PL checkpoint named differently to model.ckpt in dvc.yaml, so it might require a more complex stage than the one I've proposed. I think, elsewhere, you proposed using a foreach stage to checkpoint at regular intervals?

EDIT: here

a-orn, Jan 07 '22 03:01

I think it still results in a PL checkpoint named differently to model.ckpt

With filename='checkpoint' it will always use that single name, won't it? I can't check it right now.

I think, elsewhere, you proposed using a foreach stage to checkpoint at regular intervals?

Yep, I think it's the easiest way. If you have something more elegant in mind, you're welcome to propose it.

sirily, Jan 07 '22 03:01

I recently discovered the experimental 'Checkpointing IO API' for PyTorch Lightning:

import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.callbacks import ModelCheckpoint

from dvclive.lightning import DvcLiveLogger

# Set path to file checkpointed in dvc.yaml
model_file = 'model.ckpt'
# The previous checkpoint is kept next to it as old.ckpt
old_file = os.path.join(os.path.dirname(model_file), 'old.ckpt')

# Override default naming pattern: ignore the path PL proposes
# and always write to model_file
class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        if os.path.exists(model_file):
            # os.replace (unlike os.rename) also overwrites on Windows
            os.replace(model_file, old_file)
        torch.save(checkpoint, model_file)

    def load_checkpoint(self, path):
        return torch.load(model_file)

    def remove_checkpoint(self, path):
        # ModelCheckpoint calls this to discard a superseded checkpoint
        if os.path.exists(old_file):
            os.remove(old_file)

dvclive_logger = DvcLiveLogger(
    path='logs',
    resume=True,
    summary=True
)

m = MyModel()  # MyModel: your own LightningModule (not shown)

trainer = Trainer(
    logger=dvclive_logger,
    plugins=CustomCheckpointIO(),
    callbacks=ModelCheckpoint()
)

# Resume from the checkpoint if it exists; start fresh on the first run
trainer.fit(m, ckpt_path=model_file if os.path.exists(model_file) else None)

This is starting to get to the workflow I expected. I call dvc exp run and it resumes from the last checkpointed epoch, running until it completes or is killed. The checkpoint is overwritten each time and dvc.lock is updated. When I've finished training, I commit the experiment, and the checksum of the final checkpoint file is tracked in git for use in downstream tasks.

@daavoo is this the pattern you had in mind for #140 and #174?

Docs at: https://pytorch-lightning.readthedocs.io/en/stable/advanced/checkpoint_io.html

a-orn, Jan 11 '22 05:01


Thanks for the code sample, @a-orn! I need to try it.

@daavoo is this the pattern you had in mind for #140 and #174?

The pattern is to overwrite the model on each save call, like #174 does, and rely on the DVC Checkpoints functionality to resume/restart.

Overwritten models are not lost, thanks to the DVC checkpoints mechanism, so there is no need to keep the previous one as old.ckpt.
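A minimal stdlib sketch of that overwrite-on-save pattern, assuming a single DVC-tracked file; pickle stands in for torch.save, and save_in_place is a hypothetical helper. Writing to a temporary file in the same directory and then calling os.replace should mean an interrupted save (e.g. Ctrl+C mid-write) never leaves a half-written model_file behind, since the rename is atomic on POSIX within one filesystem.

```python
import os
import pickle
import tempfile


def save_in_place(obj, model_file):
    """Hypothetical helper: overwrite model_file with obj, replacing it in one step."""
    directory = os.path.dirname(os.path.abspath(model_file))
    os.makedirs(directory, exist_ok=True)
    # Temp file in the same directory so os.replace stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)  # stand-in for torch.save(checkpoint, f)
        os.replace(tmp_path, model_file)  # swap in the new file
    except BaseException:  # includes KeyboardInterrupt
        os.remove(tmp_path)
        raise
```

Each call leaves exactly one file at model_file, which is what a checkpoint: true output in dvc.yaml expects.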

I think that we should implement what @sirily suggested:

We could implement the after_save_checkpoint method and save model_file there. But it will duplicate files if the user doesn't disable native checkpointing, so we should encourage users to do so.

And emphasize in the docs how to disable native checkpointing in order to prevent duplication.

The checkpoint_io plugin makes sense, but it has the downside of requiring the user to pass 2 different things to the trainer instead of just the callback. I think that

daavoo, Jan 11 '22 08:01

@daavoo

I think that we should implement what @sirily suggested:

I've tried it, but PL doesn't call after_save_checkpoint if there is no checkpoint callback (checkpoint_callback=False in the Trainer initialization). And if we do use a checkpoint callback, we can handle all the naming there, without any additional code.

import os
import pytorch_lightning as pl
from dvclive.lightning import DvcLiveLogger
# LitMNIST is the example LightningModule from dvclive's test suite
from tests.test_lightning import LitMNIST

if __name__ == '__main__':
    model = LitMNIST()
    # set checkpoint path
    ckpt_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "checkpoints"))
    # checkpoints will be saved to checkpoints/model.ckpt
    cp = pl.callbacks.ModelCheckpoint(
        monitor="train_loss_epoch",
        save_top_k=1,
        dirpath=ckpt_path,
        filename='model')

    # Set path to file checkpointed in dvc.yaml
    model_file = os.path.join(ckpt_path, "model.ckpt")

    # init logger
    dvclive_logger = DvcLiveLogger()
    # set None if it is the first run
    resume_from_checkpoint = model_file if os.path.exists(model_file) else None
    trainer = pl.Trainer(
        logger=dvclive_logger, max_epochs=10,
        callbacks=[cp]
    )
    trainer.fit(model, ckpt_path=resume_from_checkpoint)

With a dvc.yaml like this:

stages:
  train:
    cmd: python train.py
    outs:
      - checkpoints/model.ckpt:
          checkpoint: true

This is the workflow users expect, I guess: dvc exp run begins an experiment and resumes it if possible, and model.ckpt is tracked by DVC. But we still have to use the ModelCheckpoint callback. I don't see any way to save checkpoints properly using only the logger.

There is one problem with this approach: DVC always tries to resume training, even if you change some parameter, or track the .py file as a dependency and change it. If I use the plugin as @a-orn suggested, I get the same behavior. And I have to delete old experiments or clear the dvc.lock file. The approach from the December community gems seems more convenient to me.

sirily, Jan 16 '22 12:01

@sirily If I understand you correctly, your solution from the community gems then solves two different issues you have:

  1. Unlike other frameworks, PyTorch Lightning logging does not expose hooks to methods like on_train_batch_end, making it challenging to save the model in a dvclive integration.
  2. Checkpoints in DVC always resume training from the previous checkpoint. There is an ongoing discussion and proposed behavior change in https://github.com/iterative/dvc/discussions/6104#discussioncomment-1838965. Would that proposal simplify your workflow? Is that the behavior you would expect?

dberenbaum, Jan 18 '22 18:01

@dberenbaum yes, I expect DVC to resume an experiment only after a failure. If it is interrupted by Ctrl+C or completes, I expect DVC to start a new experiment (when something has changed, of course). And if the experiment failed but some parameter has changed, I expect DVC to start a new experiment as well. With such a workflow, it would be possible to use the simple code from my previous post in this thread with the native PL checkpoint callback. But one would still have to use the PL callback with the dvclive logger; we can't implement saving inside the logger.
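The resume policy described here could be written as a small decision function. This is a hypothetical sketch of the expected behavior, not how DVC decides today; the status strings ("failed", "interrupted", "completed") and the params dict are assumptions, and a real check would also compare dependency hashes such as train.py.

```python
import hashlib
import json


def params_fingerprint(params):
    """Stable hash of a params dict; sort_keys makes key order irrelevant."""
    blob = json.dumps(params, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


def should_resume(last_status, last_fingerprint, params):
    """Hypothetical policy (not DVC's actual logic): resume only a failed,
    otherwise unchanged run."""
    if last_status != "failed":
        # Completed or interrupted (Ctrl+C) runs are never resumed
        return False
    # A failed run with changed params starts a new experiment instead
    return last_fingerprint == params_fingerprint(params)
```

For example, a run that failed with {"lr": 0.01} is resumed only if the next dvc exp run still uses {"lr": 0.01}.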

sirily, Jan 18 '22 18:01