pytorch-lightning

Inconsistency in loading from checkpoint in LightningCLI

Open. Northo opened this issue 8 months ago.

Bug description

When using a checkpoint with LightningCLI, the model is first instantiated and the checkpoint is then loaded by passing it to the ckpt_path argument of the corresponding Trainer method (e.g. Trainer.predict).

The problem is that the hyperparameters stored in the checkpoint are not used when instantiating the model, and therefore not when allocating its tensors. Checkpoint loading then fails if the tensor sizes do not match. Furthermore, if the model has more complex instantiation logic, this can lead to other failures or silent bugs.

This was first raised as a discussion in #20715.
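The failure mode itself does not depend on LightningCLI. As a minimal sketch in plain PyTorch (the `TinyModel` class is hypothetical, standing in for `DemoModel`), building a module from its default arguments and then loading a state dict produced with a different `out_dim` raises the same size-mismatch error, while instantiating with the checkpoint's value first works:

```python
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for DemoModel: the layer size depends on out_dim."""

    def __init__(self, out_dim: int = 10):
        super().__init__()
        self.l1 = nn.Linear(32, out_dim)


# Weights "trained" with a non-default hyperparameter, as with --model.out_dim 2
state_dict = TinyModel(out_dim=2).state_dict()


def load_into_default() -> None:
    # Mirrors the current behavior: instantiate from defaults/CLI args, then load weights
    model = TinyModel()  # out_dim=10, the checkpoint's out_dim is ignored
    model.load_state_dict(state_dict)


try:
    load_into_default()
except RuntimeError as err:
    print("size mismatch" in str(err))  # True

# Instantiating with the checkpoint's value first makes loading succeed
fixed = TinyModel(out_dim=2)
fixed.load_state_dict(state_dict)
```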

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Here is a minimal example using the predict subcommand. We set out_dim during fit so that the last layer has a non-default size, which causes checkpoint loading in predict to fail.

# cli.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class DemoModelWithHyperparameters(DemoModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the init arguments (e.g. out_dim) in the checkpoint's hyper_parameters
        self.save_hyperparameters()

def cli_main():
    cli = LightningCLI(DemoModelWithHyperparameters, BoringDataModule)

if __name__ == "__main__":
    cli_main()

and then run

$ python src/lightning_cli_load_checkpoint/cli.py fit --trainer.max_epochs 1 --model.out_dim 2
$ python src/lightning_cli_load_checkpoint/cli.py predict --ckpt_path <path_to_checkpoint>

Error messages and logs

Restoring states from the checkpoint path at lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt
Traceback (most recent call last):
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 20, in <module>
    cli_main()
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 16, in cli_main
    cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 398, in __init__
    self._run_subcommand(self.subcommand)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
    fn(**fn_kwargs)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 887, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 928, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 409, in _restore_modules_and_callbacks
    self.restore_model()
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 286, in restore_model
    self.trainer.strategy.load_model_state_dict(
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 372, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for DemoModelWithHyperparameters:
	size mismatch for l1.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([10, 32]).
	size mismatch for l1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([10]).

Expected behavior

I'd expect the model to be instantiated with the hyperparameters stored in the checkpoint. In other words, while the current implementation roughly follows this logic:

model = Model(**cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)

I'd expect it to be closer to

model = Model.load_from_checkpoint(ckpt_path, **cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)
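A minimal sketch of the expected order of operations, without Lightning: read the hyperparameters from the checkpoint, apply explicitly given CLI values on top, instantiate, then load the weights. The checkpoint here is a plain dict using the `state_dict` and `hyper_parameters` keys that Lightning checkpoints contain; `TinyModel`, `make_checkpoint` and `load_with_checkpoint_hparams` are hypothetical names for illustration only.

```python
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model that records its init arguments, similar to save_hyperparameters()."""

    def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
        super().__init__()
        self.hparams = {"out_dim": out_dim, "learning_rate": learning_rate}
        self.l1 = nn.Linear(32, out_dim)
        self.learning_rate = learning_rate


def make_checkpoint(model: TinyModel) -> dict:
    # Stand-in for a Lightning checkpoint: only the two keys used below
    return {"state_dict": model.state_dict(), "hyper_parameters": dict(model.hparams)}


def load_with_checkpoint_hparams(model_cls, ckpt: dict, **overrides):
    # Checkpoint hyperparameters form the base; explicit overrides win
    hparams = {**ckpt["hyper_parameters"], **overrides}
    model = model_cls(**hparams)  # tensors are allocated with the checkpoint's sizes
    model.load_state_dict(ckpt["state_dict"])
    return model


ckpt = make_checkpoint(TinyModel(out_dim=2))
model = load_with_checkpoint_hparams(TinyModel, ckpt, learning_rate=0.001)
print(model.l1.out_features, model.learning_rate)  # 2 0.001
```

Note that the overrides must contain only values the user actually gave on the command line. If parser defaults were passed as well, the default out_dim=10 would again replace the checkpoint's value.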

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • pytorch-lightning: 2.5.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
  • Packages:
    • aiohappyeyeballs: 2.6.1
    • aiohttp: 3.11.16
    • aiosignal: 1.3.2
    • antlr4-python3-runtime: 4.9.3
    • attrs: 25.3.0
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • contourpy: 1.3.1
    • cycler: 0.12.1
    • docstring-parser: 0.16
    • filelock: 3.18.0
    • fonttools: 4.57.0
    • frozenlist: 1.5.0
    • fsspec: 2025.3.2
    • hydra-core: 1.3.2
    • idna: 3.10
    • importlib-metadata: 8.0.0
    • importlib-resources: 6.5.2
    • inflect: 7.3.1
    • jaraco.collections: 5.1.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jinja2: 3.1.6
    • jsonargparse: 4.38.0
    • kiwisolver: 1.4.8
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • markdown-it-py: 3.0.0
    • markupsafe: 3.0.2
    • matplotlib: 3.10.1
    • mdurl: 0.1.2
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • multidict: 6.4.3
    • networkx: 3.4.2
    • numpy: 2.2.4
    • omegaconf: 2.3.0
    • packaging: 24.2
    • pillow: 11.2.1
    • platformdirs: 4.2.2
    • propcache: 0.3.1
    • protobuf: 6.30.2
    • pygments: 2.19.1
    • pyparsing: 3.2.3
    • python-dateutil: 2.9.0.post0
    • pytorch-lightning: 2.5.1
    • pyyaml: 6.0.2
    • rich: 13.9.4
    • setuptools: 78.1.0
    • six: 1.17.0
    • sympy: 1.13.1
    • tensorboardx: 2.6.2.2
    • tomli: 2.0.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
    • tqdm: 4.67.1
    • typeguard: 4.3.0
    • typeshed-client: 2.7.0
    • typing-extensions: 4.13.2
    • wheel: 0.45.1
    • yarl: 1.19.0
    • zipp: 3.19.2
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.12.7
    • release: 24.4.0
    • version: Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000

Northo, May 06 '25 06:05

After thinking about it, I believe implementing this is not trivial. The main difficulty comes from argument links: regardless of how the code is changed, it must keep working correctly when argument links are added to the parser. I am not sure I am aware of all the details that need to be considered, and some may only become apparent during implementation. For now I would propose the following:

  1. Behavior only changes when LightningCLI is working in subcommand mode, i.e. run=True.
  2. Before instantiating and running, check whether ckpt_path is passed to the subcommand. Note that in principle ckpt_path could be given as a command line argument, in a config file, or as an environment variable. A simple solution is to parse once, check whether ckpt_path is set, and if it is, parse a second time.
  3. If ckpt_path is set, use torch.load to read the checkpoint and check whether hyperparameters are included (i.e. save_hyperparameters was used).
  4. If hyperparameters are included, remove the keys that correspond to link targets (both links applied on parse and links applied on instantiate). Unfortunately, right now there is no official way (in the jsonargparse public API) to know which keys are link targets.
  5. After removing link targets, parse again, modifying the args so that right after the subcommand (e.g. predict), and before all other arguments, there is a new --config option whose value is the modified hyperparameters from the checkpoint.
  6. Continue the normal flow, which instantiates the classes and then runs the trainer method.

I need to figure out what to do about point 4. Most likely a new feature in jsonargparse is needed.
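Steps 4 and 5 can be sketched on plain data structures. This is not LightningCLI code: `checkpoint_config` and `inject_config` are hypothetical helpers, and two things are assumptions: that link targets are known as dotted names such as "model.num_classes", and that the filtered hyperparameters are written to a config file whose path is passed to --config. Placing --config directly after the subcommand means the checkpoint values replace parser defaults, while any arguments given later on the command line still override them.

```python
def checkpoint_config(hparams: dict, link_targets: set[str]) -> dict:
    """Step 4: drop link targets and Lightning's internal _instantiator key."""
    kept = {
        key: value
        for key, value in hparams.items()
        # Assumption: link targets are named "model.<init argument>"
        if key != "_instantiator" and f"model.{key}" not in link_targets
    }
    # Nested under "model" because the config is relative to the subcommand's parser
    return {"model": kept}


def inject_config(argv: list[str], subcommand: str, config_path: str) -> list[str]:
    """Step 5: insert --config right after the subcommand, before all other arguments."""
    idx = argv.index(subcommand)
    return argv[: idx + 1] + ["--config", config_path] + argv[idx + 1 :]


hparams = {"out_dim": 2, "learning_rate": 0.02, "num_classes": 5, "_instantiator": "placeholder"}
config = checkpoint_config(hparams, link_targets={"model.num_classes"})
argv = inject_config(["predict", "--ckpt_path", "last.ckpt"], "predict", "ckpt_hparams.yaml")
print(config)  # {'model': {'out_dim': 2, 'learning_rate': 0.02}}
print(argv)  # ['predict', '--config', 'ckpt_hparams.yaml', '--ckpt_path', 'last.ckpt']
```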

mauvilsa, May 08 '25 09:05

Hi, is there any progress on this? I think this is a very big issue. How can the CLI even be useful if we can't run a simple fit followed by a test?

python main.py fit --model.out_dim 2
python main.py test --ckpt_path path/to/ckpt

marawangamal, Jul 06 '25 19:07

Hey @Borda, can you confirm this bug and assign me to work on it? Thanks 😊

fnhirwa, Jul 09 '25 11:07

I was looking into "remove the keys that correspond to link targets" from my comment above, and it turns out not to be the problem I thought it was. This means the implementation is straightforward and no changes to jsonargparse are needed. I have created #21116 with the fix.

mauvilsa, Aug 25 '25 03:08

Since review, merge, and release can take a while, below is a temporary patch for people who don't want to wait.

from pathlib import Path

import torch
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def parse_arguments(self, *args, **kwargs):
        # Fail loudly once lightning ships _parse_ckpt_path itself, so this patch gets removed
        if hasattr(LightningCLI, "_parse_ckpt_path"):
            raise RuntimeError("_parse_ckpt_path already available in lightning, this patch is no longer needed")
        super().parse_arguments(*args, **kwargs)
        self._parse_ckpt_path()

    def _parse_ckpt_path(self):
        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
        if not self.config.get("subcommand"):
            return  # not running in subcommand mode
        ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
        if ckpt_path and Path(ckpt_path).is_file():
            ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
            hparams = ckpt.get("hyper_parameters", {})
            hparams.pop("_instantiator", None)  # not a model init argument
            if hparams:
                # Apply the checkpoint's model hyperparameters on top of the already-parsed config
                hparams = {self.config.subcommand: {"model": hparams}}
                self.config = self.parser.parse_object(hparams, self.config)
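To use the patch, construct MyCLI instead of LightningCLI in cli_main. Before relying on it, it can help to check which hyperparameters a checkpoint actually contains, since the patch has no effect when save_hyperparameters was not called. The helper below is a hypothetical sketch that reads a checkpoint the same way the patch does (`torch.load` with `weights_only=True`, dropping `_instantiator`); the demo checkpoint it writes is a fabricated minimal dict, not a real Lightning checkpoint.

```python
import os
import tempfile
from pathlib import Path

import torch


def read_checkpoint_hparams(ckpt_path: str) -> dict:
    """Return the model hyperparameters stored in a checkpoint, read as in the patch."""
    if not Path(ckpt_path).is_file():
        return {}
    ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    hparams = dict(ckpt.get("hyper_parameters", {}))
    hparams.pop("_instantiator", None)  # dropped, as in the patch
    return hparams


# Demo with a minimal hand-written checkpoint (only the keys the helper reads)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.ckpt")
    torch.save(
        {
            "state_dict": {"l1.bias": torch.zeros(2)},
            "hyper_parameters": {"out_dim": 2, "_instantiator": "placeholder"},
        },
        path,
    )
    print(read_checkpoint_hparams(path))  # {'out_dim': 2}
```

An empty result means the model was saved without hyperparameters, so the CLI falls back to its defaults and the size mismatch above can still occur.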

mauvilsa, Aug 25 '25 03:08