
CheckpointConnector should support loading with `strict=False`

Open Ir1d opened this issue 3 years ago • 10 comments

🐛 Bug

pytorch_lightning/trainer/connectors/checkpoint_connector.py should support loading with strict=False

To Reproduce

Expected behavior

Say you add a new torchmetrics module to your LightningModule: you won't be able to validate / test existing checkpoints, because the module now has state_dict keys that the old checkpoint does not contain.

This affects all Trainer functions that take ckpt_path as a parameter, such as validate(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None).
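For illustration, here is a minimal sketch of the failure mode. The class name, the extra submodule, and the flag are hypothetical; a plain Linear layer stands in for the torchmetrics module, since any newly added submodule reproduces the problem:

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, with_new_head: bool = False):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        if with_new_head:
            # newly added submodule: its parameters become extra state_dict
            # keys that older checkpoints do not contain
            self.new_head = torch.nn.Linear(2, 2)

old_state = LitModel(with_new_head=False).state_dict()  # what an old checkpoint holds
new_model = LitModel(with_new_head=True)

# strict=True (what the checkpoint connector effectively uses) fails with
# "Missing key(s) in state_dict":
#   new_model.load_state_dict(old_state)
# strict=False succeeds, but the Trainer never exposes this flag:
new_model.load_state_dict(old_state, strict=False)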

Environment

  • PyTorch Lightning Version (e.g., 1.5.0):
  • PyTorch Version (e.g., 1.10):
  • Python version (e.g., 3.9):
  • OS (e.g., Linux):
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source):
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

cc @borda @tchaton @justusschock @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj

Ir1d avatar Jun 07 '22 15:06 Ir1d

This is a fair ask that has come up a few times in the past. Users want to be able to configure the strict argument here (which currently defaults to strict=True): https://github.com/PyTorchLightning/pytorch-lightning/blob/83436ee3dfd0d4079e0f8e704ba76aca672af19d/pytorch_lightning/strategies/strategy.py#L317-L322

I can think of two solutions: (a) route a strict: bool argument all the way from the trainer entry points to the checkpoint connector, as is done for ckpt_path:

trainer.validate(..., ckpt_path=..., strict=False)

(b): Make this configurable in the strategy with an attribute:

trainer = Trainer(...)
trainer.strategy.strict_loading = False
trainer.validate(...)
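For (b), a rough sketch of how the flag could be threaded into the load; the class and attribute name are illustrative only, not the current Lightning API:

import pytorch_lightning as pl

class SketchStrategy:
    """Illustrative stand-in for the real Strategy; only the loading path is shown."""

    def __init__(self, lightning_module: pl.LightningModule) -> None:
        self.lightning_module = lightning_module
        self.strict_loading = True  # default preserves today's strict behavior

    def load_model_state_dict(self, checkpoint: dict) -> None:
        # today this call is effectively hardcoded to strict=True;
        # the idea is to thread the user-settable attribute through instead
        self.lightning_module.load_state_dict(
            checkpoint["state_dict"], strict=self.strict_loading
        )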

Thoughts @awaelchli @justusschock?

carmocca avatar Jun 07 '22 16:06 carmocca

I believe this is critical. If you use PyTorch Lightning to iterate on your model design, everything breaks as soon as you modify the module.

Ir1d avatar Jun 07 '22 17:06 Ir1d

tbh, I like neither of those approaches @carmocca 🙈

Usually you only want to load with strict=False if you intend to make changes or have already made some, so for me it is a model property (similar to automatic_optimization and truncated_bptt_steps). If you make it a static attribute, it can also be checked before instantiation, in case LM.load_from_checkpoint is used. A sketch of what that could look like follows below.
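For illustration only, a static attribute on the module might look like this; the attribute name strict_loading is an assumption, not an existing API:

import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    # class-level flag, in the spirit of `automatic_optimization`; being a class
    # attribute, it can be inspected before instantiation, e.g. when
    # MyModel.load_from_checkpoint is used
    strict_loading = False

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)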

justusschock avatar Jun 07 '22 20:06 justusschock

Sounds good to me.

@Ir1d Would you be interested in contributing this feature?

carmocca avatar Jun 08 '22 11:06 carmocca

Sorry, I don't have much time to work on this @carmocca

Ir1d avatar Jun 09 '22 00:06 Ir1d

Hmm, actually I had already modified the PyTorch Lightning code so that the PyTorch Lightning CLI accepts strict=False for my needs, and it works. Maybe I can contribute a PR in the next couple of days, following the PyTorch Lightning PR standards.

stevenwudi avatar Jun 19 '22 15:06 stevenwudi

I like @justusschock's suggestion for making it a static attribute on the class.

awaelchli avatar Jun 21 '22 09:06 awaelchli

Hi, this is still very much an issue! What is the workaround? Particularly for PL < 2 users.

Also, as a side note: why is the trainer even loading the state dict from the checkpoint? trainer.fit takes a model instance as input, which can itself be loaded from a checkpoint, so it's redundant to load it again.

thesofakillers avatar Jul 14 '23 14:07 thesofakillers

For anyone looking for a temporary hacky workaround, here's what I did. My workflow was:

# the state dict in ckpt does not contain the pretrained model, because we train with that frozen, so no point in saving it
model = ModelClass.load_from_checkpoint(ckpt)
pretrained_model = load_pretrained_model()
# adds the pt_model module and freezes it
model.set_pt_model(pretrained_model)
# this was causing issues because now my model.state_dict() had pt_model keys, while the ckpt state dict did not
trainer.fit(model, datamodule, ckpt_path=ckpt)

In ModelClass, we implement on_load_checkpoint in the following way:

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """
    The pt_model is trained separately, so we already have access to its
    checkpoint and load it separately with `self.set_pt_model`.

    However, the PL Trainer is strict about checkpoint loading (not
    configurable), so it expects the loaded state_dict to match the model
    state_dict exactly.

    So, before the keys are matched, we copy all pt_model keys from
    self.state_dict() into the checkpoint state dict so that they match.
    """
    state_dict = self.state_dict()  # build the state dict once, not per key
    for key in state_dict:
        if key.startswith("pt_model"):
            checkpoint["state_dict"][key] = state_dict[key]

The lack of comments here, and the relatively common use case of frozen pretrained (PT) models, seem to suggest that I'm doing something wrong. Are people saving the frozen pretrained weights along with their model? Because I was deleting them to save space.

thesofakillers avatar Jul 14 '23 15:07 thesofakillers

So there is still no official support in lightning.Trainer.fit for loading a checkpoint with strict=False?

ecolss avatar Jan 21 '24 10:01 ecolss