pytorch-lightning
CheckpointConnector should support loading with `strict=False`
🐛 Bug
pytorch_lightning/trainer/connectors/checkpoint_connector.py should support loading with strict=False
To Reproduce
Expected behavior
Say you add a new torchmetrics module to your LightningModule: you won't be able to validate or test existing checkpoints, because the new module adds keys that the checkpoint's state dict does not contain.
This affects every Trainer function that takes ckpt_path as a parameter, such as validate(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None).
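Conceptually, strict loading fails whenever the model's keys and the checkpoint's keys diverge in either direction. A minimal, framework-free sketch of that check (this is a hypothetical helper for illustration; the real check lives in `torch.nn.Module.load_state_dict`):

```python
def check_state_dict_keys(model_keys, ckpt_keys, strict=True):
    """Mimic strict state-dict matching: report keys that differ.

    Hypothetical helper, not Lightning's actual code.
    """
    missing = sorted(set(model_keys) - set(ckpt_keys))      # in model, not in ckpt
    unexpected = sorted(set(ckpt_keys) - set(model_keys))   # in ckpt, not in model
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return missing, unexpected

# Adding a torchmetrics attribute introduces keys the old checkpoint lacks:
model_keys = {"layer.weight", "layer.bias", "train_acc.correct"}
ckpt_keys = {"layer.weight", "layer.bias"}

# strict=True would raise here; strict=False merely reports the mismatch
missing, unexpected = check_state_dict_keys(model_keys, ckpt_keys, strict=False)
```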
Environment
- PyTorch Lightning Version (e.g., 1.5.0):
- PyTorch Version (e.g., 1.10):
- Python version (e.g., 3.9):
- OS (e.g., Linux):
- CUDA/cuDNN version:
- GPU models and configuration:
- How you installed PyTorch (conda, pip, source):
- If compiling from source, the output of torch.__config__.show():
- Any other relevant information:
Additional context
cc @borda @tchaton @justusschock @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj
This is a fair ask that has come up a few times in the past. Users want to configure the strict flag here: https://github.com/PyTorchLightning/pytorch-lightning/blob/83436ee3dfd0d4079e0f8e704ba76aca672af19d/pytorch_lightning/strategies/strategy.py#L317-L322
I can think of two solutions:
(a): Route a strict=bool argument all the way from the trainer entry points to the checkpoint connector, as is done for ckpt_path:
trainer.validate(..., ckpt_path=..., strict=False)
(b): Make this configurable in the strategy with an attribute:
trainer = Trainer(...)
trainer.strategy.strict_loading = False
trainer.validate(...)
Thoughts @awaelchli @justusschock?
I believe this is critical. If you use PyTorch Lightning to iterate on your model design, everything breaks as soon as you modify the module.
tbh, I like neither of those approaches @carmocca 🙈
Usually you only want to load with strict=False if you intend to make changes or have already made some, so to me it is a model property (similar to automatic_optimization and truncated_bptt_steps). If you make it a static attribute, it can also be checked before instantiation in case LM.load_from_checkpoint is used.
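The static-attribute idea can be inspected without instantiating the module, which is what makes it work for load_from_checkpoint. A minimal sketch of that pattern (the class and attribute names here are illustrative, not Lightning's API):

```python
class BaseModule:
    # class-level default, analogous to the proposed strict-loading property
    strict_loading: bool = True

class FineTunedModule(BaseModule):
    # a user opts out of strict checkpoint loading for this model class
    strict_loading = False

def should_load_strictly(cls) -> bool:
    """Framework-side check: readable from the class itself, so it can run
    before the object is built (e.g. before load_from_checkpoint)."""
    return getattr(cls, "strict_loading", True)

print(should_load_strictly(FineTunedModule))  # → False
print(should_load_strictly(BaseModule))       # → True
```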
Sounds good to me.
@Ir1d Would you be interested in contributing this feature?
Sorry, I don't have much time to work on this @carmocca
Hmm, actually I had already modified the PyTorch Lightning code to let the PyTorch Lightning CLI accept strict=False for my own needs, and it works. Maybe I can contribute a PR in the next couple of days, following the PyTorch Lightning PR standards.
I like @justusschock's suggestion for making it a static attribute on the class.
Hi, this is still an issue and it's urgent! What is the workaround, particularly for PL < 2 users?
Also, as a side note: why is the trainer loading the state dict from the checkpoint at all? trainer.fit takes a model instance as input, which can itself be loaded from a checkpoint, so it's redundant to load it again.
For anyone looking for a temporary hacky workaround, here's what I did. My workflow was:

```python
# The state dict in ckpt does not contain the pretrained model: we train
# with it frozen, so there is no point in saving it.
model = ModelClass.load_from_checkpoint(ckpt)
pretrained_model = load_pretrained_model()

# Adds the pt_model submodule and freezes it.
model.set_pt_model(pretrained_model)

# This was failing because model.state_dict() now has pt_model.* keys,
# while the checkpoint's state dict does not.
trainer.fit(model, datamodule, ckpt_path=ckpt)
```
In ModelClass, we implement on_load_checkpoint in the following way:

```python
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """The pt_model is trained separately, so we already have access to its
    checkpoint and load it separately with `self.set_pt_model`.

    However, the PL Trainer is strict about checkpoint loading (not
    configurable), so it expects the loaded state_dict to match the model's
    state_dict exactly. Therefore, before the keys are matched, we copy all
    pt_model keys from self.state_dict() into the checkpoint's state dict
    so that they match.
    """
    state_dict = self.state_dict()  # call once instead of once per key
    for key, value in state_dict.items():
        if key.startswith("pt_model"):
            checkpoint["state_dict"][key] = value
```
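The key-injection trick above can be exercised without Lightning: the hook only needs to copy the model's prefixed entries into the checkpoint dict before the strict match runs. A framework-free sketch, using plain dicts in place of tensors (function and variable names are illustrative):

```python
def inject_missing_prefixed_keys(model_state, checkpoint, prefix="pt_model"):
    """Copy entries under `prefix` from the live model's state dict into the
    checkpoint's state dict, so a later strict key match succeeds.

    Illustrative stand-in for the on_load_checkpoint hook above.
    """
    for key, value in model_state.items():
        if key.startswith(prefix):
            # setdefault: keep any value the checkpoint already has
            checkpoint["state_dict"].setdefault(key, value)
    return checkpoint

# the saved checkpoint lacks the frozen pt_model weights...
checkpoint = {"state_dict": {"head.weight": [1.0]}}
# ...but the live model has them after set_pt_model() was called
model_state = {"head.weight": [1.0], "pt_model.embed": [0.5]}

inject_missing_prefixed_keys(model_state, checkpoint)
# checkpoint["state_dict"] now also contains "pt_model.embed",
# so strict loading no longer reports it as a missing key
```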
The lack of comments here, together with how common it is to use frozen pretrained models, suggests to me that I'm doing something wrong. Are people saving the frozen pretrained weights along with their model? I was deleting them to save space.
So there is still no official support in lightning.Trainer.fit for loading checkpoints with strict=False?