lightly
Add documentation and helper functions to store and retrieve checkpoints
We have a model zoo but no documentation on how to load the models from it. We also have the PyTorch Lightning wrapper for training the models but no information on how to use the models afterward.
It would be great to provide the following (documentation and, where it makes sense, helper functions):
How to load a model that has been saved by PyTorch Lightning (e.g. I just want to get the ResNet backbone from it)?
```python
import torch
import lightly

# run lightly from CLI or using the PyTorch Lightning wrapper
...
# now you should have a folder with a checkpoint
# let's load the checkpoint in another script (e.g. to do transfer learning)
ckpt = torch.load('my_lightning_checkpoint.pth')
my_resnet = lightly.models.ResNetGenerator()
# load checkpoint state dict to my_resnet
... # TODO: how to load ckpt state dict into my_resnet
```
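One way to fill in the TODO above is to pick the backbone's entries out of the checkpoint's state dict and strip their common prefix. The sketch below is a hypothetical helper, not part of lightly's API. It assumes the checkpoint was written by PyTorch Lightning (so torch.load returns a dict with a 'state_dict' entry) and that all backbone keys share one prefix. The prefix 'model.backbone' used in the commented usage is only a guess that depends on the attribute names in your LightningModule, so print the keys first to find the right one.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def extract_submodule_state_dict(
    state_dict: Mapping[str, V], prefix: str
) -> "OrderedDict[str, V]":
    """Hypothetical helper: keep the entries under `prefix` and strip it.

    For example, with prefix 'backbone' the key 'backbone.conv1.weight'
    becomes 'conv1.weight', while 'projection_head.0.weight' is dropped.
    """
    # Match whole module names only, so 'backbone' does not match 'backbone2.x'
    if prefix and not prefix.endswith("."):
        prefix += "."
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Usage with a Lightning checkpoint (not executed here; the prefix is an assumption):
# ckpt = torch.load('my_lightning_checkpoint.pth', map_location='cpu')
# print(list(ckpt['state_dict'].keys())[:5])  # inspect the key names first
# backbone_state = extract_submodule_state_dict(ckpt['state_dict'], 'model.backbone')
# my_resnet.load_state_dict(backbone_state)
```

Because load_state_dict is strict by default, it raises an error listing missing and unexpected keys if the prefix is wrong, which makes it a quick check of the result.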
How to manually save and load a model?
```python
import torch
import lightly

# simple example of storing and reading a state dict
backbone = lightly.models.ResNetGenerator()
simclr_model = lightly.models.SimCLR(backbone)
# save only the backbone weights (simclr_model.state_dict() would also contain
# the projection head, and its keys would not match a plain backbone)
torch.save({'model': backbone.state_dict()}, 'my_checkpoint.pth')
# load the backbone later (can be in another script)
ckpt = torch.load('my_checkpoint.pth')
backbone = lightly.models.ResNetGenerator()
backbone.load_state_dict(ckpt['model'])
```
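Since the issue also asks for helper functions, here is a minimal sketch of two hypothetical helpers that wrap the pattern above (save_checkpoint and load_checkpoint are illustrative names, not lightly functions). Following the general-checkpoint recipe from the PyTorch saving/loading tutorial, they store the optimizer state and arbitrary metadata such as the epoch next to the model weights, so training can be resumed. An in-memory io.BytesIO buffer stands in for a file path; torch.save and torch.load accept either.

```python
import io
from typing import Any, BinaryIO, Dict, Optional, Union

import torch
from torch import nn

PathOrBuffer = Union[str, BinaryIO]


def save_checkpoint(
    target: PathOrBuffer,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    **metadata: Any,
) -> None:
    """Hypothetical helper: store model weights, optimizer state and metadata."""
    ckpt: Dict[str, Any] = {"model": model.state_dict(), **metadata}
    if optimizer is not None:
        ckpt["optimizer"] = optimizer.state_dict()
    torch.save(ckpt, target)


def load_checkpoint(
    source: PathOrBuffer,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    map_location: str = "cpu",
) -> Dict[str, Any]:
    """Hypothetical helper: restore weights (and optimizer state) in place.

    Returns the whole checkpoint dict so metadata such as the epoch can be read.
    """
    ckpt = torch.load(source, map_location=map_location)
    model.load_state_dict(ckpt["model"])
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt


# Round trip with a small stand-in model; a path like 'my_checkpoint.pth' works too
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
buffer = io.BytesIO()
save_checkpoint(buffer, model, optimizer, epoch=3)
buffer.seek(0)  # rewind the buffer before reading it back

restored = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.1)
ckpt = load_checkpoint(buffer, restored, restored_optimizer)
print(ckpt["epoch"])  # 3
```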
Is this issue free to work on?
Yes, it is. However, parts may change with the ongoing refactoring, so I'd recommend working on something else for now.
Sure
(Reply by email to Philipp Wirth's comment of Mon, 20 Sep 2021, 12:27.)
Has this been added to the documentation?
Hi @shree-lily, unfortunately we haven't added this to the documentation yet.
You can still use either PyTorch or PyTorch Lightning to save and load checkpoints:
- https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
- https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html
Taking our SimCLR example code (https://docs.lightly.ai/examples/simclr.html) as a starting point, you can save and load the model as follows:
```python
...
# saving
model = SimCLR(backbone)
torch.save(model.state_dict(), 'simclr_model.ckpt')
# loading
model = SimCLR(backbone)
model.load_state_dict(torch.load('simclr_model.ckpt'))
```
It is common to drop the projection head for downstream or inference tasks and use only the backbone, which you can access through the model.backbone attribute.
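To illustrate dropping the projection head, the sketch below saves only model.backbone.state_dict() and later loads it into a freshly built backbone, for example for linear evaluation. ToySimCLR and make_backbone are hypothetical stand-ins so the snippet runs without data or a real ResNet; ToySimCLR assumes the same backbone/projection_head layout as the SimCLR example class, but its tiny linear layers are not the real architecture.

```python
import io

import torch
from torch import nn


def make_backbone() -> nn.Module:
    # Hypothetical tiny backbone; in practice this would be e.g. a ResNet
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU())


class ToySimCLR(nn.Module):
    # Assumed layout: a backbone followed by a projection head
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection_head(self.backbone(x))


model = ToySimCLR(make_backbone())  # imagine this model has been trained

# Save the backbone weights only; the projection head is dropped
buffer = io.BytesIO()  # stands in for a path such as 'backbone.ckpt'
torch.save(model.backbone.state_dict(), buffer)
buffer.seek(0)

# Later (possibly in another script): rebuild the backbone and load the weights
backbone = make_backbone()
backbone.load_state_dict(torch.load(buffer, map_location="cpu"))
backbone.eval()

# Linear evaluation: freeze the backbone and train only a new classifier head
for param in backbone.parameters():
    param.requires_grad = False
classifier = nn.Sequential(backbone, nn.Linear(8, 10))

with torch.no_grad():
    features = backbone(torch.randn(2, 16))
print(features.shape)  # torch.Size([2, 8])
```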