
Add documentation and helper functions to store and retrieve checkpoints

Open IgorSusmelj opened this issue 4 years ago • 5 comments

We have a model zoo but no documentation on how to load the models from it. We also have the PyTorch Lightning wrapper for training the models but no information on how to use the models afterward.

It would be great to provide the following (documentation and, where it makes sense, helper functions):

How to load a model that has been saved by PyTorch Lightning (e.g. I just want to get the ResNet backbone from it)?

import torch
import lightly

# run lightly from CLI or using the PyTorch Lightning wrapper
...
# now you should have a folder with a checkpoint
# let's load the checkpoint in another script (e.g. to do transfer learning)
ckpt = torch.load('my_lightning_checkpoint.pth')
my_resnet = lightly.models.ResNetGenerator()

# load checkpoint state dict to my_resnet
... # TODO: how to load ckpt state dict into my_resnet
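One possible way to fill in the TODO, sketched under assumptions: PyTorch Lightning stores the model weights under the `"state_dict"` key of its checkpoint, and each key is prefixed with the attribute path of the submodule inside the LightningModule (e.g. `backbone.` if the module stored the network as `self.backbone`). The helper below, `extract_sub_state_dict`, is hypothetical and not part of lightly; it keeps only the keys under a given prefix and strips that prefix so they can be loaded into a standalone module. The target module must have exactly the same structure as the selected submodule, otherwise `load_state_dict` will report missing or unexpected keys.

```python
from typing import Any, Dict, Mapping


def extract_sub_state_dict(state_dict: Mapping[str, Any], prefix: str) -> Dict[str, Any]:
    """Return the entries whose keys start with `prefix`, with the prefix removed.

    Hypothetical helper (not a lightly API). An empty prefix returns all entries.
    """
    # Treat "backbone" and "backbone." the same, and avoid matching "backbone_extra.*"
    if prefix and not prefix.endswith("."):
        prefix += "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Usage sketch (assumes torch and lightly are installed, and that the
# LightningModule stored a module matching my_resnet as `self.backbone`):
#
#   ckpt = torch.load("my_lightning_checkpoint.pth", map_location="cpu")
#   backbone_weights = extract_sub_state_dict(ckpt["state_dict"], "backbone")
#   my_resnet.load_state_dict(backbone_weights)
```

The helper works on plain mappings, so the prefix handling can be checked without loading any real checkpoint.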

How to manually save and load a model?

# simple example of storing and reading a state dict
import torch
import lightly

backbone = lightly.models.ResNetGenerator()
simclr_model = lightly.models.SimCLR(backbone)

# save only the backbone weights (the SimCLR state dict also contains the
# projection head, and its keys carry a submodule prefix, so it would not
# load into a bare ResNetGenerator)
torch.save({'model': backbone.state_dict()}, 'my_checkpoint.pth')

# load the backbone later (can be in another script)
ckpt = torch.load('my_checkpoint.pth')
backbone = lightly.models.ResNetGenerator()
backbone.load_state_dict(ckpt['model'])
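A possible shape for the requested helper functions, as a minimal sketch: `save_backbone` and `load_backbone` are hypothetical names (not existing lightly APIs) that wrap the `torch.save` / `torch.load` pattern above, using the same `'model'` key. A small `nn.Sequential` stands in for the ResNet so the round trip can be run without lightly; any module with a matching architecture works the same way.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_backbone(backbone: nn.Module, path: str) -> None:
    """Store only the backbone weights under the key 'model' (hypothetical helper)."""
    torch.save({"model": backbone.state_dict()}, path)


def load_backbone(backbone: nn.Module, path: str) -> nn.Module:
    """Load weights written by save_backbone into a module with the same architecture."""
    # map_location="cpu" lets a checkpoint written on a GPU machine load anywhere
    ckpt = torch.load(path, map_location="cpu")
    backbone.load_state_dict(ckpt["model"])
    return backbone


# Round trip with a small stand-in backbone; with lightly you would pass
# a ResNetGenerator instance instead.
source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
path = os.path.join(tempfile.mkdtemp(), "my_checkpoint.pth")
save_backbone(source, path)

target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
load_backbone(target, path)
```

Saving a dict rather than the bare state dict leaves room to store extra entries (e.g. an epoch counter) in the same file later.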

IgorSusmelj • Feb 13 '21 14:02

Is this issue free to work on?

Nike682631 • Sep 17 '21 17:09

Yes, it is. However, parts may change with the ongoing refactoring, so I'd recommend working on something else for now.

philippmwirth • Sep 20 '21 06:09

Sure


Nike682631 • Sep 20 '21 06:09

Has this been added to the documentation?

shree-lily • Sep 21 '22 21:09

Hi @shree-lily, unfortunately we have not added this to the documentation yet.

You can still use either PyTorch or PyTorch Lightning to save and load checkpoints:

  • https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
  • https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html

Taking our example code for the SimCLR model (https://docs.lightly.ai/examples/simclr.html), you can save and load the model as follows:

...

# saving
model = SimCLR(backbone)
torch.save(model.state_dict(), 'simclr_model.ckpt')

# loading
model = SimCLR(backbone)
model.load_state_dict(torch.load('simclr_model.ckpt'))

It is common to drop the projection head for downstream or inference tasks and use only the backbone. You can access the backbone through the model.backbone attribute.
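A minimal sketch of dropping the projection head and using only the backbone for inference. `ToySimCLR` is a hypothetical stand-in, not lightly's SimCLR or the example's model; it only assumes the same layout of a `backbone` attribute followed by a `projection_head`, and a tiny MLP replaces the ResNet so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class ToySimCLR(nn.Module):
    """Hypothetical SimCLR-style model: a backbone followed by a projection head."""

    def __init__(self, backbone: nn.Module, feature_dim: int, out_dim: int):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim), nn.ReLU(), nn.Linear(feature_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection_head(self.backbone(x))


backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
model = ToySimCLR(backbone, feature_dim=16, out_dim=4)

# After training (or after model.load_state_dict(...)), keep only the backbone.
feature_extractor = model.backbone
feature_extractor.eval()  # put layers such as BatchNorm/Dropout into inference mode
with torch.no_grad():  # no gradients needed when only extracting features
    features = feature_extractor(torch.randn(2, 3, 8, 8))
print(features.shape)  # torch.Size([2, 16])
```

The features come out with the backbone's dimensionality (16 here), not the projection head's output size (4), which is what downstream tasks such as linear evaluation or transfer learning would typically consume.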

guarin • Sep 22 '22 06:09